Coverage for hadal/faiss_search.py: 76%
33 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-09 22:55 +0300
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-09 22:55 +0300
1"""This module contains the `class FaissSearch` that can be used to perform k-nearest neighbor search using the Faiss library."""
2from __future__ import annotations
4import logging
5from typing import TYPE_CHECKING
7import faiss
9from hadal.custom_logger import default_custom_logger
11if TYPE_CHECKING:
12 import numpy
class FaissSearch:
    """Class to perform k-nearest neighbor search using the Faiss library.

    Methods:
        k_nearest_neighbors: Perform k-nearest neighbor search using Faiss.
    """

    # Public metric name -> name of the matching constant on the `faiss` module.
    # The attribute *name* is stored (not the value) so that an unknown metric
    # can be rejected before the faiss module is touched at all.
    _METRIC_ATTRS = {
        "inner_product": "METRIC_INNER_PRODUCT",
        "sqeuclidean": "METRIC_L2",  # squared Euclidean (L2) distance
    }

    def __init__(self, device: str | None = None, *, enable_logging: bool = True, log_level: int | None = logging.INFO) -> None:
        """Initialize FaissSearch object.

        Args:
            device (str | None, optional): Device for the Faiss search. If `None`, it will use GPU
                (`"cuda"`) if `faiss.get_num_gpus()` reports at least one, otherwise `"cpu"`. Default is `None`.
            enable_logging (bool, optional): Logging option. Logging is enabled only when this is exactly `True`.
            log_level (int | None, optional): Logging level passed to the custom logger.
        """
        if enable_logging is True:
            self.logger = default_custom_logger(name=__name__, level=log_level)
        else:
            # NOTE: this disables the shared module-level logger, not a private copy.
            self.logger = logging.getLogger(__name__)
            self.logger.disabled = True

        if device is None:
            device = "cuda" if faiss.get_num_gpus() > 0 else "cpu"
            self.logger.info("Faiss device: %s", device)

        self._target_device = device

    def k_nearest_neighbors(
        self,
        source_embeddings: numpy.ndarray,
        target_embeddings: numpy.ndarray,
        k: int = 4,
        knn_metric: str = "inner_product",
        device: str | None = None,
    ) -> tuple[numpy.ndarray, numpy.ndarray]:
        """Perform k-nearest neighbor search using Faiss.

        Args:
            source_embeddings (numpy.ndarray): The source embeddings (used as queries).
            target_embeddings (numpy.ndarray): The target embeddings (searched database).
            k (int, optional): The number of nearest neighbors.
            knn_metric (str, optional): The metric to use for k-nearest neighbor search. Can be `inner_product` or `sqeuclidean`.
            device (str | None, optional): The device to use for Faiss search. If `None`, the device selected
                when the object was created is used. `"cpu"` runs on CPU; any other value runs the GPU path.

        Note:
            It is fully relying on the Faiss library for the k-nearest neighbor search `faiss.knn` and `faiss.knn_gpu`.

            - `inner_product` uses `faiss.METRIC_INNER_PRODUCT`
            - `sqeuclidean` uses `faiss.METRIC_L2` (squared Euclidean distance)

        Returns:
            - d (numpy.ndarray): The distances of the k-nearest neighbors.
            - ind (numpy.ndarray): The indices of the k-nearest neighbors.

        Raises:
            ValueError: If `knn_metric` is a string other than `inner_product` or `sqeuclidean`.
        """
        if device is None:
            device = self._target_device

        self.logger.info("Perform k-nearest neighbor search...")

        # Translate the metric name into the Faiss constant. Non-string values
        # (e.g. a raw `faiss.METRIC_*` constant) are forwarded to Faiss unchanged.
        if isinstance(knn_metric, str):
            metric_attr = self._METRIC_ATTRS.get(knn_metric)
            if metric_attr is None:
                supported = ", ".join(f"`{name}`" for name in self._METRIC_ATTRS)
                msg = f"Unknown knn_metric `{knn_metric}`. Supported metrics: {supported}."
                raise ValueError(msg)
            knn_metric = getattr(faiss, metric_attr)

        if device == "cpu":
            self.logger.info("Using faiss knn on CPU...")
            # https://github.com/facebookresearch/faiss/blob/d85601d972af2d64103769ab8d940db28aaae2a0/faiss/python/extra_wrappers.py#L330
            d, ind = faiss.knn(xq=source_embeddings, xb=target_embeddings, k=k, metric=knn_metric)
        else:
            self.logger.info("Using faiss knn on GPU...")
            res = faiss.StandardGpuResources()
            # The Faiss GPU wrapper is `knn_gpu` and its keyword is `metric`
            # (there is no `faiss.gpu_knn` / `metric_type`).
            d, ind = faiss.knn_gpu(res=res, xq=source_embeddings, xb=target_embeddings, k=k, metric=knn_metric)

        self.logger.info("Done k-nearest neighbor search!")
        return d, ind