Coverage for hadal/parallel_sentence_mining/margin_based/margin_based.py: 91%
43 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-09 22:55 +0300
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-09 22:55 +0300
1"""This module contains the `class MarginBasedPipeline` that implements the margin-based pipeline."""
2from __future__ import annotations
4import logging
5from typing import TYPE_CHECKING
7from hadal.custom_logger import default_custom_logger
8from hadal.faiss_search import FaissSearch
9from hadal.huggingface_automodel import HuggingfaceAutoModel
10from hadal.parallel_sentence_mining.margin_based.margin_based_tools import MarginBased
12if TYPE_CHECKING:
13 import pathlib
15 import numpy
class MarginBasedPipeline:
    """Class that implements the margin-based pipeline.

    The pipeline encodes source and target sentences with a
    `HuggingfaceAutoModel`, runs a k-nearest-neighbour search in both
    directions with `FaissSearch`, scores the candidates with a margin
    function from `MarginBased` and returns the best sentence pairs.

    Methods:
        make_alignments: Make sentence alignments.
    """

    def __init__(
        self,
        model_name_or_path: str | pathlib.Path,
        model_device: str | None = None,
        faiss_device: str | None = None,
        *,
        enable_logging: bool = True,
        log_level: int | None = logging.INFO,
    ) -> None:
        """Initialize a MarginBasedPipeline object.

        Args:
            model_name_or_path (str | pathlib.Path): Name or path to the pre-trained model.
            model_device (str | None, optional): Device for the model.
            faiss_device (str | None, optional): Device for the Faiss search. If `None`, it will use GPU if available, otherwise CPU.
            enable_logging (bool, optional): Logging option.
            log_level (int | None, optional): Logging level.
        """
        self.model = HuggingfaceAutoModel(
            model_name_or_path=model_name_or_path,
            device=model_device,
            enable_logging=enable_logging,
        )
        self.align_method = MarginBased()
        self.faiss_search = FaissSearch(device=faiss_device, enable_logging=enable_logging)
        self.model_device = model_device
        self.faiss_device = faiss_device

        if enable_logging:
            self.logger = default_custom_logger(name=__name__, level=log_level)
        else:
            # Keep a logger object so calls such as `self.logger.info(...)` stay valid,
            # but silence it.
            self.logger = logging.getLogger(__name__)
            self.logger.disabled = True

    def make_alignments(
        self,
        source_sentences: list[str],
        target_sentences: list[str],
        batch_size: int = 32,
        output_value: str = "pooler_output",
        convert_to: str = "numpy",
        *,
        normalize_embeddings: bool = True,
        knn_neighbors: int = 4,
        knn_metric: str = "inner_product",
        margin: str = "ratio",
        strategy: str = "max_score",
    ) -> list[tuple[numpy.float64, str, str]]:
        """Make sentence alignments.

        Args:
            source_sentences (list[str]): Source sentences.
            target_sentences (list[str]): Target sentences.
            batch_size (int, optional): The batch size.
            output_value (str, optional): Model output type. Can be `pooler_output` or `last_hidden_state`.
            convert_to (str, optional): Convert the embeddings to `torch` or `numpy` format. If `torch`, it will return a `torch.Tensor`. If `numpy`, it will return a `numpy.ndarray`. If `None`, it will return a `list[torch.Tensor]`.
            normalize_embeddings (bool, optional): Normalize the embeddings.
            knn_neighbors (int, optional): The number of nearest neighbors. In each search direction it is capped at the number of sentences being searched.
            knn_metric (str, optional): The metric to use for k-nearest neighbor search. Can be `inner_product` or `l2`.
            margin (str, optional): The margin function to use. Valid options are `ratio` and `distance`.
            strategy (str, optional): The strategy to use for selecting the best candidates.

        Returns:
            bitext_list (list[tuple[numpy.float64, str, str]]): The `list[tuple[score, source_sentence, target_sentence]]` of the best sentence alignments. An empty list is returned when either input is empty.
        """
        source_count = len(source_sentences)
        target_count = len(target_sentences)
        if source_count == 0 or target_count == 0:
            # Nothing can be aligned; skip the encoder and the kNN search entirely.
            self.logger.warning("Empty source or target sentences, nothing to align")
            return []

        # Asking the index for more neighbours than it holds vectors makes the
        # search pad the result with placeholder entries, which would distort the
        # neighbourhood means used by the margin. Cap k per direction instead.
        x2y_neighbors = min(knn_neighbors, target_count)
        y2x_neighbors = min(knn_neighbors, source_count)

        self.logger.info("Encoding embeddings for source sentences...")
        source_embeddings = self.model.encode(
            sentences=source_sentences,
            batch_size=batch_size,
            output_value=output_value,
            convert_to=convert_to,
            normalize_embeddings=normalize_embeddings,
        )
        self.logger.info("Encoding embeddings for target sentences...")
        target_embeddings = self.model.encode(
            sentences=target_sentences,
            batch_size=batch_size,
            output_value=output_value,
            convert_to=convert_to,
            normalize_embeddings=normalize_embeddings,
        )

        self.logger.info("Perform kNN in both directions...")
        self.logger.info("Perform kNN in source -> target direction")
        x2y_sim, x2y_ind = self.faiss_search.k_nearest_neighbors(
            source_embeddings,
            target_embeddings,
            k=x2y_neighbors,
            knn_metric=knn_metric,
        )
        # Mean similarity of every source sentence to its nearest targets.
        x2y_mean = x2y_sim.mean(axis=1)

        self.logger.info("Perform kNN in target -> source direction")
        y2x_sim, y2x_ind = self.faiss_search.k_nearest_neighbors(
            target_embeddings,
            source_embeddings,
            k=y2x_neighbors,
            knn_metric=knn_metric,
        )
        # Mean similarity of every target sentence to its nearest sources.
        y2x_mean = y2x_sim.mean(axis=1)

        self.logger.info("%s margin is selected", margin)
        chosen_margin = self.align_method.select_margin(margin=margin)

        self.logger.info("Compute forward and backward scores...")
        fwd_scores = self.align_method.margin_based_score_candidates(
            source_embeddings,
            target_embeddings,
            x2y_ind,
            x2y_mean,
            y2x_mean,
            margin=chosen_margin,
        )
        bwd_scores = self.align_method.margin_based_score_candidates(
            target_embeddings,
            source_embeddings,
            y2x_ind,
            y2x_mean,
            x2y_mean,
            margin=chosen_margin,
        )

        self.logger.info("Selecting best candidates...")
        indices, scores = self.align_method.select_best_candidates(
            source_embeddings,
            x2y_ind,
            fwd_scores,
            target_embeddings,
            y2x_ind,
            bwd_scores,
            strategy=strategy,
        )

        bitext_list = self.align_method.get_sentence_pairs(indices, scores, source_sentences, target_sentences)

        self.logger.info("Output sentences: pairs %d", len(bitext_list))

        return bitext_list