Coverage for hadal/faiss_search.py: 76%

33 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-09 22:55 +0300

1"""This module contains the `class FaissSearch` that can be used to perform k-nearest neighbor search using the Faiss library.""" 

2from __future__ import annotations 

3 

4import logging 

5from typing import TYPE_CHECKING 

6 

7import faiss 

8 

9from hadal.custom_logger import default_custom_logger 

10 

11if TYPE_CHECKING: 

12 import numpy 

13 

14 

15class FaissSearch: 

16 """Class to perform k-nearest neighbor search using the Faiss library. 

17 

18 Methods: 

19 k_nearest_neighbors: Perform k-nearest neighbor search using Faiss. 

20 """ 

21 

22 def __init__(self, device: str | None = None, *, enable_logging: bool = True, log_level: int | None = logging.INFO) -> None: 

23 """Initialize FaissSearch object. 

24 

25 Args: 

26 device (str | None, optional): Device for the Faiss search. If `None`, it will use GPU if available, otherwise CPU. Default is `None`. 

27 enable_logging (bool, optional): Logging option. 

28 log_level (int | None, optional): Logging level. 

29 """ 

30 if enable_logging is True: 

31 self.logger = default_custom_logger(name=__name__, level=log_level) 

32 else: 

33 self.logger = logging.getLogger(__name__) 

34 self.logger.disabled = True 

35 

36 if device is None: 

37 device = "cuda" if faiss.get_num_gpus() > 0 else "cpu" 

38 self.logger.info("Faiss device: %s", device) 

39 

40 self._target_device = device 

41 

42 def k_nearest_neighbors( 

43 self, 

44 source_embeddings: numpy.ndarray, 

45 target_embeddings: numpy.ndarray, 

46 k: int = 4, 

47 knn_metric: str = "inner_product", 

48 device: str | None = None, 

49 ) -> tuple[numpy.ndarray, numpy.ndarray]: 

50 """Perform k-nearest neighbor search using Faiss. 

51 

52 Args: 

53 source_embeddings (numpy.ndarray): The source embeddings. 

54 target_embeddings (numpy.ndarray): The target embeddings. 

55 k (int, optional): The number of nearest neighbors. 

56 knn_metric (str, optional): The metric to use for k-nearest neighbor search. Can be `inner_product` or `sqeuclidean`. 

57 device (str | None, optional): The device to use for Faiss search. If `None`, it will use GPU if available, otherwise CPU. 

58 

59 Note: 

60 It is fully relying on the Faiss library for the k-nearest neighbor search `faiss.knn` and `faiss.gpu_knn`. 

61 

62 - `inner_product` uses `faiss.METRIC_INNER_PRODUCT` 

63 - `sqeuclidean` uses `faiss.METRIC_L2` (squared Euclidean distance) 

64 

65 Returns: 

66 - d (numpy.ndarray): The distances of the k-nearest neighbors. 

67 - ind (numpy.ndarray): The indices of the k-nearest neighbors. 

68 """ 

69 if device is None: 

70 device = self._target_device 

71 

72 self.logger.info("Perform k-nearest neighbor search...") 

73 

74 if knn_metric == "inner_product": 

75 knn_metric = faiss.METRIC_INNER_PRODUCT 

76 elif knn_metric == "sqeuclidean": 

77 # squared Euclidean (L2) distance 

78 knn_metric = faiss.METRIC_L2 

79 

80 if device == "cpu": 

81 self.logger.info("Using faiss knn on CPU...") 

82 # https://github.com/facebookresearch/faiss/blob/d85601d972af2d64103769ab8d940db28aaae2a0/faiss/python/extra_wrappers.py#L330 

83 d, ind = faiss.knn(xq=source_embeddings, xb=target_embeddings, k=k, metric=knn_metric) 

84 else: 

85 self.logger.info("Using faiss knn on GPU...") 

86 res = faiss.StandardGpuResources() 

87 d, ind = faiss.gpu_knn(res=res, xq=source_embeddings, xb=target_embeddings, k=k, metric_type=knn_metric) 

88 

89 self.logger.info("Done k-nearest neighbor search!") 

90 return d, ind