Coverage for hadal/torch_search.py: 0%
18 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-07 20:48 +0300
1"""Test."""
2from __future__ import annotations
4import logging
6import torch
8from hadal.custom_logger import default_custom_logger
class TorchSearch:
    """Embedding search helpers backed by PyTorch.

    Work in progress: only the inner-product kernel is implemented;
    ``k_nearest_neighbors`` is still a placeholder.

    Args:
        device: Torch device string (e.g. ``"cpu"``, ``"cuda"``). When ``None``,
            ``"cuda"`` is chosen if ``torch.cuda.is_available()``, else ``"cpu"``.
        enable_logging: If true, obtain a logger from ``default_custom_logger``;
            otherwise use a disabled logger that drops every record.
        log_level: Level forwarded to ``default_custom_logger`` when logging
            is enabled.
    """

    def __init__(
        self, device: str | None = None, *, enable_logging: bool = True, log_level: int | None = logging.INFO
    ) -> None:
        if enable_logging:
            self.logger = default_custom_logger(name=__name__, level=log_level)
        else:
            # Use a dedicated child logger rather than disabling
            # ``logging.getLogger(__name__)``: that logger object is shared
            # process-wide, so disabling it could also silence instances
            # created with logging enabled (if ``default_custom_logger`` hands
            # back the same named logger).
            self.logger = logging.getLogger(f"{__name__}.disabled")
            self.logger.disabled = True

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.logger.info("Pytorch Search device: %s", device)

        self._target_device = torch.device(device)

    def _inner_product(self, source_embeddings: torch.Tensor, target_embeddings: torch.Tensor) -> torch.Tensor:
        """Return the pairwise inner products of source and target embeddings.

        Assumes embeddings are stored row-wise (last dimension is the
        embedding dimension) -- TODO confirm against callers. For 2-D inputs
        of shape ``(n, d)`` and ``(m, d)`` the result has shape ``(n, m)``,
        with entry ``[i, j]`` equal to the dot product of
        ``source_embeddings[i]`` and ``target_embeddings[j]``. Leading batch
        dimensions broadcast as in ``torch.matmul``. A 1-D
        ``target_embeddings`` is treated as a single vector.

        Args:
            source_embeddings: Query embeddings.
            target_embeddings: Embeddings to compare against.

        Returns:
            Tensor of inner products.
        """
        if target_embeddings.dim() < 2:
            # A lone vector has no matrix axes to swap; matmul handles it directly.
            return torch.matmul(source_embeddings, target_embeddings)
        # Swap the last two axes so each target row becomes a column of the
        # right-hand operand.
        return torch.matmul(source_embeddings, target_embeddings.transpose(-2, -1))

    def k_nearest_neighbors(self):
        """Placeholder for k-nearest-neighbour search; not implemented yet.

        Raises:
            NotImplementedError: Always, so callers never silently receive
                ``None`` from an unfinished search.
        """
        msg = "TorchSearch.k_nearest_neighbors is not implemented yet"
        raise NotImplementedError(msg)