Source code for skhubness.neighbors.random_projection_trees

# -*- coding: utf-8 -*-
# SPDX-License-Identifier: BSD-3-Clause
# Author: Tom Dupre la Tour (original work)
#         Roman Feldbauer (adaptions for scikit-hubness)
# PEP 563: Postponed Evaluation of Annotations
from __future__ import annotations

import logging
import sys
from typing import Union, Tuple

# ``annoy`` is an optional dependency. Fail with a regular ImportError (rather
# than printing and terminating the interpreter) so that importing code can
# catch the error and the original cause stays visible in the traceback.
try:
    import annoy
except ImportError as exc:  # pragma: no cover
    raise ImportError(
        "The package 'annoy' is required to use RandomProjectionTree."
    ) from exc

import numpy as np

from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
from tqdm.auto import tqdm
from .approximate_neighbors import ApproximateNearestNeighbor
from ..utils.check import check_n_candidates
from ..utils.io import create_tempfile_preferably_in_dir

# Public API of this module. (No output is produced at import time.)
__all__ = ['RandomProjectionTree',
           ]


class RandomProjectionTree(BaseEstimator, ApproximateNearestNeighbor):
    """Wrapper for using annoy.AnnoyIndex

    Annoy is an approximate nearest neighbor library,
    that builds a forest of random projections trees.

    Parameters
    ----------
    n_candidates: int, default = 5
        Number of neighbors to retrieve
    metric: str, default = 'euclidean'
        Distance metric, allowed are "angular", "euclidean", "manhattan", "hamming", "dot".
        In addition, ``fit`` treats "minkowski" as "euclidean", and "sqeuclidean"
        is served by a euclidean index whose distances are squared in ``kneighbors``.
    n_trees: int, default = 10
        Build a forest of n_trees trees. More trees gives higher precision when querying,
        but are more expensive in terms of build time and index size.
    search_k: int, default = -1
        Query will inspect search_k nodes. A larger value will give more accurate results,
        but will take longer time.
    mmap_dir: str, default = 'auto'
        Memory-map the index to the given directory.
        This is required to make the class pickleable.
        If None, keep everything in main memory (NON pickleable index),
        if mmap_dir is a string, it is interpreted as a directory to store the index into,
        if 'auto', create a temp dir for the index, preferably in /dev/shm on Linux.
    n_jobs: int, default = 1
        Number of parallel jobs
    verbose: int, default = 0
        Verbosity level. If verbose > 0, show tqdm progress bar on indexing and querying.

    Attributes
    ----------
    valid_metrics:
        List of valid distance metrics/measures
    """
    valid_metrics = ["angular", "euclidean", "manhattan", "hamming", "dot", "minkowski"]

    def __init__(self, n_candidates: int = 5, metric: str = 'euclidean',
                 n_trees: int = 10, search_k: int = -1, mmap_dir: str = 'auto',
                 n_jobs: int = 1, verbose: int = 0):
        super().__init__(n_candidates=n_candidates,
                         metric=metric,
                         n_jobs=n_jobs,
                         verbose=verbose,
                         )
        self.n_trees = n_trees
        self.search_k = search_k
        self.mmap_dir = mmap_dir

    def fit(self, X, y=None) -> RandomProjectionTree:
        """ Build the annoy.Index and insert data from X.

        Parameters
        ----------
        X: np.array
            Data to be indexed
        y: any
            Not used for indexing. If given, it is validated together with X
            and stored in ``y_train_``.

        Returns
        -------
        self: RandomProjectionTree
            An instance of RandomProjectionTree with a built index
        """
        if y is None:
            X = check_array(X)
        else:
            X, y = check_X_y(X, y)
            self.y_train_ = y

        self.n_samples_fit_ = X.shape[0]
        self.n_features_ = X.shape[1]
        self.X_dtype_ = X.dtype

        # Resolve the metric handed to annoy in a local variable, so that the
        # constructor parameter ``self.metric`` is left untouched by ``fit``.
        # 'minkowski' is accepted for compatibility, and 'sqeuclidean' is
        # obtained by squaring euclidean distances at query time.
        if self.metric in ('minkowski', 'sqeuclidean'):
            metric = 'euclidean'
        else:
            metric = self.metric
        self.effective_metric_ = metric

        # Decide where the index will live, again without modifying the
        # constructor parameter ``self.mmap_dir``.
        if self.mmap_dir == 'auto':
            index_path = create_tempfile_preferably_in_dir(prefix='skhubness_',
                                                           suffix='.annoy',
                                                           directory='/dev/shm')
            logging.warning(f'The index will be stored in {index_path}. '
                            f'It will NOT be deleted automatically, when this instance is destructed.')
        elif isinstance(self.mmap_dir, str):
            index_path = create_tempfile_preferably_in_dir(prefix='skhubness_',
                                                           suffix='.annoy',
                                                           directory=self.mmap_dir)
        else:  # e.g. None: keep the index in main memory
            index_path = None

        annoy_index = annoy.AnnoyIndex(X.shape[1], metric=metric)
        for i, x in tqdm(enumerate(X),
                         total=X.shape[0],
                         desc='Build RPtree',
                         disable=not self.verbose,
                         ):
            annoy_index.add_item(i, x.tolist())
        annoy_index.build(self.n_trees)

        # ``annoy_`` is assigned only after the index was built (and saved),
        # so a failed fit does not leave the estimator looking fitted.
        if index_path is None:
            self.annoy_ = annoy_index
        else:
            annoy_index.save(index_path)
            self.annoy_ = index_path

        return self

    def kneighbors(self, X=None, n_candidates=None, return_distance=True) -> Union[Tuple[np.array, np.array], np.array]:
        """ Retrieve k nearest neighbors.

        Parameters
        ----------
        X: np.array or None, optional, default = None
            Query objects. If None, search among the indexed objects.
        n_candidates: int or None, optional, default = None
            Number of neighbors to retrieve.
            If None, use the value passed during construction.
        return_distance: bool, default = True
            If return_distance, will return distances and indices to neighbors.
            Else, only return the indices.

        Returns
        -------
        neigh_dist: np.array, only if return_distance
            Distances to the neighbors; NaN where fewer neighbors were found.
        neigh_ind: np.array
            Indices of the neighbors; -1 where fewer neighbors were found.

        Raises
        ------
        TypeError
            If the fitted ``annoy_`` attribute is neither a path (str)
            nor an annoy.AnnoyIndex.
        """
        check_is_fitted(self, 'annoy_')
        if X is not None:
            X = check_array(X)

        n_test = self.n_samples_fit_ if X is None else X.shape[0]
        dtype = self.X_dtype_ if X is None else X.dtype

        if n_candidates is None:
            n_candidates = self.n_candidates
        n_candidates = check_n_candidates(n_candidates)

        # For compatibility reasons, as each sample is considered as its own
        # neighbor, one extra neighbor will be computed (and the first hit dropped).
        if X is None:
            n_neighbors, start = n_candidates + 1, 1
        else:
            n_neighbors, start = n_candidates, 0

        # If fewer candidates than required are found for a query,
        # we save index=-1 and distance=NaN.
        neigh_ind = np.full((n_test, n_candidates), -1, dtype=np.int32)
        # Multiplying by NaN (instead of filling) lets numpy promote an
        # integer input dtype to a float dtype that can hold NaN.
        neigh_dist = np.empty_like(neigh_ind, dtype=dtype) * np.nan

        # Load memory-mapped annoy.Index, unless it's already in main memory
        if isinstance(self.annoy_, str):
            annoy_index = annoy.AnnoyIndex(self.n_features_, metric=self.effective_metric_)
            annoy_index.load(self.annoy_)
        elif isinstance(self.annoy_, annoy.AnnoyIndex):
            annoy_index = self.annoy_
        else:
            raise TypeError('Internal error: unexpected type for annoy index')

        # Pick the per-query lookup once, so that a single loop serves both
        # querying indexed items and querying new vectors.
        if X is None:
            n_queries = annoy_index.get_n_items()

            def query(i):
                return annoy_index.get_nns_by_item(
                    i, n_neighbors, self.search_k,
                    include_distances=True,
                )
        else:
            n_queries = X.shape[0]

            def query(i):
                return annoy_index.get_nns_by_vector(
                    X[i].tolist(), n_neighbors, self.search_k,
                    include_distances=True,
                )

        for i in tqdm(range(n_queries),
                      desc='Query RPtree',
                      disable=not self.verbose,
                      ):
            ind, dist = query(i)
            ind = ind[start:]
            dist = dist[start:]
            neigh_ind[i, :len(ind)] = ind
            neigh_dist[i, :len(dist)] = dist

        # The index itself is euclidean; squared distances are derived here.
        if self.metric == 'sqeuclidean':
            neigh_dist **= 2

        if return_distance:
            return neigh_dist, neigh_ind
        else:
            return neigh_ind