Coverage for hadal/parallel_sentence_mining/margin_based/margin_based_tools.py: 85%
47 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-09 22:55 +0300
1"""The module contains the `class MarginBased` that implements the margin-based scoring for parallel sentence mining."""
2from __future__ import annotations
4from typing import TYPE_CHECKING
6if TYPE_CHECKING:
7 from collections.abc import Callable
9import numpy
class MarginBased:
    """Margin-based scoring helpers for parallel sentence mining.

    The scoring scheme follows Artetxe & Schwenk, "Margin-based Parallel Corpus Mining with
    Multilingual Sentence Embeddings" (https://arxiv.org/pdf/1811.01136.pdf).

    Methods:
        select_margin: Return the margin function registered under a name.
        margin_based_score: Score a single source/target pair.
        margin_based_score_candidates: Score every (source, candidate target) pair of a batch.
        select_best_candidates: Keep the best forward and backward candidate for each sentence.
        get_sentence_pairs: Greedily turn candidate indices and scores into sentence pairs.
    """

    def __init__(self) -> None:
        """Create a MarginBased helper; no instance state is set up."""

    def select_margin(self, margin: str = "ratio") -> Callable:
        """Return the margin function for the given name.

        Source: https://arxiv.org/pdf/1811.01136.pdf 3.1 Margin-based scoring

        Args:
            margin (str, optional): Name of the margin function. `ratio` yields `a / b`,
                `distance` yields `a - b`. Defaults to `ratio`.

        Raises:
            NotImplementedError: If `margin` is neither `ratio` nor `distance`.

        Returns:
            Callable: A two-argument function `(a, b) -> margin`.
        """

        def ratio_margin(a, b):
            return a / b

        def distance_margin(a, b):
            return a - b

        # Compared with `==` (not a dict lookup) so matching semantics stay plain string equality.
        for name, func in (("ratio", ratio_margin), ("distance", distance_margin)):
            if margin == name:
                return func

        msg = f"margin=`{margin}` is not implemented"
        raise NotImplementedError(msg)

    def margin_based_score(
        self,
        source_embeddings: numpy.ndarray,
        target_embeddings: numpy.ndarray,
        fwd_mean: numpy.ndarray,
        bwd_mean: numpy.ndarray,
        margin_func: Callable,
    ) -> numpy.ndarray:
        """Score one source/target pair with the given margin function.

        Source: https://arxiv.org/pdf/1811.01136.pdf 3.1 Margin-based scoring

        The score is `margin_func(source . target, (fwd_mean + bwd_mean) / 2)`.

        Args:
            source_embeddings (numpy.ndarray): Source embedding (a single vector when called from
                `margin_based_score_candidates`).
            target_embeddings (numpy.ndarray): Target embedding (a single vector when called from
                `margin_based_score_candidates`).
            fwd_mean (numpy.ndarray): Forward mean for the source side — presumably the average
                similarity of the source to its nearest target neighbours; confirm with the caller.
            bwd_mean (numpy.ndarray): Backward mean for the target side (same caveat as `fwd_mean`).
            margin_func (Callable): Two-argument margin function, e.g. from `select_margin`.

        Returns:
            numpy.ndarray: The margin-based score.
        """
        similarity = source_embeddings.dot(target_embeddings)
        neighbourhood_mean = (fwd_mean + bwd_mean) / 2
        return margin_func(similarity, neighbourhood_mean)

    def margin_based_score_candidates(
        self,
        source_embeddings: numpy.ndarray,
        target_embeddings: numpy.ndarray,
        candidate_inds: numpy.ndarray,
        fwd_mean: numpy.ndarray,
        bwd_mean: numpy.ndarray,
        margin: Callable,
    ) -> numpy.ndarray:
        """Score every (source, candidate target) pair in a batch.

        Entry `[row, col]` of the result is the margin-based score of source `row` against target
        `candidate_inds[row, col]`, computed through `self.margin_based_score`.

        Args:
            source_embeddings (numpy.ndarray): Source embeddings, one row per source sentence.
            target_embeddings (numpy.ndarray): Target embeddings, one row per target sentence.
            candidate_inds (numpy.ndarray): 2-D array; row `i` holds candidate target indices for source `i`.
            fwd_mean (numpy.ndarray): Forward means, indexed by source position.
            bwd_mean (numpy.ndarray): Backward means, indexed by target position.
            margin (Callable): Two-argument margin function, e.g. from `select_margin`.

        Returns:
            numpy.ndarray: float64 array with the same shape as `candidate_inds`.
        """
        result = numpy.zeros(candidate_inds.shape)
        n_rows = result.shape[0]
        for row in range(n_rows):
            for col in range(result.shape[1]):
                target_idx = candidate_inds[row, col]
                # Arguments are indexed in the same order as the paired-score signature expects.
                result[row, col] = self.margin_based_score(
                    source_embeddings[row],
                    target_embeddings[target_idx],
                    fwd_mean[row],
                    bwd_mean[target_idx],
                    margin,
                )
        return result

    def select_best_candidates(
        self,
        source_embeddings: numpy.ndarray,
        x2y_ind: numpy.ndarray,
        fwd_scores: numpy.ndarray,
        target_embeddings: numpy.ndarray,
        y2x_ind: numpy.ndarray,
        bwd_scores: numpy.ndarray,
        strategy: str = "max_score",
    ) -> tuple[numpy.ndarray, numpy.ndarray]:
        """Keep the highest-scoring candidate in each direction.

        Source: https://arxiv.org/pdf/1811.01136.pdf 3.2 Candidate generation and filtering (only max. score)

        Args:
            source_embeddings (numpy.ndarray): Source embeddings (only `shape[0]` is used).
            x2y_ind (numpy.ndarray): Candidate target indices per source sentence.
            fwd_scores (numpy.ndarray): Forward scores aligned with `x2y_ind`.
            target_embeddings (numpy.ndarray): Target embeddings (only `shape[0]` is used).
            y2x_ind (numpy.ndarray): Candidate source indices per target sentence.
            bwd_scores (numpy.ndarray): Backward scores aligned with `y2x_ind`.
            strategy (str, optional): Selection strategy; only `max_score` is supported.

        Raises:
            NotImplementedError: If `strategy` is not `max_score`.

        Returns:
            - indices (numpy.ndarray): Array of shape `(n_source + n_target, 2)` whose columns are
              (source index, target index); forward pairs come first, then backward pairs.
            - scores (numpy.ndarray): The best score of each row of `indices`, in the same order.
        """
        if strategy != "max_score":
            msg = f"`{strategy}` is not implemented"
            raise NotImplementedError(msg)

        source_positions = numpy.arange(source_embeddings.shape[0])
        best_target_for_source = x2y_ind[source_positions, fwd_scores.argmax(axis=1)]

        target_positions = numpy.arange(target_embeddings.shape[0])
        best_source_for_target = y2x_ind[target_positions, bwd_scores.argmax(axis=1)]

        source_column = numpy.concatenate([source_positions, best_source_for_target])
        target_column = numpy.concatenate([best_target_for_source, target_positions])
        pair_indices = numpy.stack([source_column, target_column], axis=1)

        best_scores = numpy.concatenate([fwd_scores.max(axis=1), bwd_scores.max(axis=1)])
        return pair_indices, best_scores

    def get_sentence_pairs(
        self,
        indices: numpy.ndarray,
        scores: numpy.ndarray,
        source_sentences: list[str],
        target_sentences: list[str],
    ) -> list[tuple[numpy.float64, str, str]]:
        """Greedily build sentence pairs from candidate indices, best score first.

        Candidates are visited in descending score order. A candidate is kept only if neither its
        source nor its target sentence has already been used by a kept pair.

        Args:
            indices (numpy.ndarray): Rows of (source index, target index).
            scores (numpy.ndarray): Score of each row of `indices`.
            source_sentences (list[str]): Source sentences.
            target_sentences (list[str]): Target sentences.

        Returns:
            list[tuple[numpy.float64, str, str]]: Tuples of (score rounded to 4 decimals,
            source sentence, target sentence).
        """
        used_sources, used_targets = set(), set()
        pairs = []

        for position in numpy.argsort(-scores):
            raw_src, raw_trg = indices[position]
            src_idx = int(raw_src)
            trg_idx = int(raw_trg)

            if src_idx in used_sources or trg_idx in used_targets:
                continue

            used_sources.add(src_idx)
            used_targets.add(trg_idx)
            pairs.append(
                (numpy.round(scores[position], 4), source_sentences[src_idx], target_sentences[trg_idx]),
            )

        return pairs