Coverage for hadal/torch_search.py: 0%

18 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-07 20:48 +0300

1"""Test.""" 

2from __future__ import annotations 

3 

4import logging 

5 

6import torch 

7 

8from hadal.custom_logger import default_custom_logger 

9 

10 

class TorchSearch:
    """Search over embedding tensors using PyTorch.

    Args:
        device: Device name passed to ``torch.device``. When ``None``, ``"cuda"``
            is chosen if ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
        enable_logging: When true, the instance logger is built with
            ``default_custom_logger``; otherwise the module logger is disabled.
        log_level: Level forwarded to ``default_custom_logger`` (default
            ``logging.INFO``).
    """

    def __init__(
        self,
        device: str | None = None,
        *,
        enable_logging: bool = True,
        log_level: int | None = logging.INFO,
    ) -> None:
        if enable_logging:
            self.logger = default_custom_logger(name=__name__, level=log_level)
        else:
            # NOTE(review): logging.getLogger returns the same object for the same
            # name, so this also silences every other user of this module's logger.
            self.logger = logging.getLogger(__name__)
            self.logger.disabled = True

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Only logged when the device was auto-selected.
            self.logger.info("Pytorch Search device: %s", device)

        self._target_device = torch.device(device)

    def _inner_product(self, source_embeddings: torch.Tensor, target_embeddings: torch.Tensor) -> torch.Tensor:
        """Return the pairwise inner products between source and target embeddings.

        Embeddings are treated as row vectors along the last dimension. For
        ``source_embeddings`` of shape ``(n, d)`` and ``target_embeddings`` of
        shape ``(m, d)``, the result has shape ``(n, m)`` and entry ``[i, j]`` is
        ``source_embeddings[i] . target_embeddings[j]``. Leading batch dimensions
        broadcast as in ``torch.matmul``.
        """
        # NOTE(review): assumes both inputs are at least 2-D with the embedding
        # dimension last -- confirm against callers.
        return torch.matmul(source_embeddings, target_embeddings.transpose(-2, -1))

    def k_nearest_neighbors(self):
        """Find the k nearest neighbours (not implemented yet).

        Raises:
            NotImplementedError: always, until the search is implemented.
        """
        raise NotImplementedError("TorchSearch.k_nearest_neighbors is not implemented yet")