Coverage for hadal/parallel_sentence_mining/margin_based/margin_based_tools.py: 85%

47 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-09 22:55 +0300

1"""The module contains the `class MarginBased` that implements the margin-based scoring for parallel sentence mining.""" 

2from __future__ import annotations 

3 

4from typing import TYPE_CHECKING 

5 

6if TYPE_CHECKING: 

7 from collections.abc import Callable 

8 

9import numpy 

10 

11 

class MarginBased:
    """Class that implements the margin-based scoring for parallel sentence mining.

    Methods:
        select_margin: Select the margin function.
        margin_based_score: Compute the margin-based score.
        margin_based_score_candidates: Compute the margin-based scores for a batch of sentence pairs.
        select_best_candidates: Select the best sentence pairs.
        get_sentence_pairs: Get the sentence pairs.
    """

    def __init__(self) -> None:
        """Initialize a MarginBased object."""

    @staticmethod
    def _ratio_margin(a, b):  # noqa: ANN001, ANN205
        """Return the `ratio` margin, i.e. `a / b`."""
        return a / b

    @staticmethod
    def _distance_margin(a, b):  # noqa: ANN001, ANN205
        """Return the `distance` margin, i.e. `a - b`."""
        return a - b

    def select_margin(self, margin: str = "ratio") -> Callable:
        """Select the margin function.

        Source: https://arxiv.org/pdf/1811.01136.pdf 3.1 Margin-based scoring

        Args:
            margin (str, optional): The margin function to use. Valid options are `ratio` and `distance`.

        Raises:
            NotImplementedError: If the given `margin` is not implemented.

        Returns:
            margin_func (Callable): The margin function. It takes two arguments `(a, b)` and returns
                `a / b` for `ratio` or `a - b` for `distance`.
        """
        # Named helpers are used instead of lambdas bound to a variable, so the returned
        # callables have a meaningful name in tracebacks and can be pickled.
        if margin == "ratio":
            return MarginBased._ratio_margin
        if margin == "distance":
            return MarginBased._distance_margin

        msg = f"margin=`{margin}` is not implemented"
        raise NotImplementedError(msg)

    def margin_based_score(
        self,
        source_embeddings: numpy.ndarray,
        target_embeddings: numpy.ndarray,
        fwd_mean: numpy.ndarray,
        bwd_mean: numpy.ndarray,
        margin_func: Callable,
    ) -> numpy.ndarray:
        """Compute the margin-based score.

        Source: https://arxiv.org/pdf/1811.01136.pdf 3.1 Margin-based scoring

        The score is `margin_func(source . target, (fwd_mean + bwd_mean) / 2)`.

        Args:
            source_embeddings (numpy.ndarray): Source embeddings.
            target_embeddings (numpy.ndarray): Target embeddings.
            fwd_mean (numpy.ndarray): The forward mean.
            bwd_mean (numpy.ndarray): The backward mean.
            margin_func (Callable): The margin function.

        Returns:
            score (numpy.ndarray): Margin-based score.
        """
        # NOTE(review): the dot product presumably equals the cosine similarity, which would
        # require the embeddings to be L2-normalised by the caller - confirm.
        similarity = source_embeddings.dot(target_embeddings)
        # Average of the two means that the margin function compares the similarity against.
        mean_of_means = (fwd_mean + bwd_mean) / 2

        return margin_func(similarity, mean_of_means)

    def margin_based_score_candidates(
        self,
        source_embeddings: numpy.ndarray,
        target_embeddings: numpy.ndarray,
        candidate_inds: numpy.ndarray,
        fwd_mean: numpy.ndarray,
        bwd_mean: numpy.ndarray,
        margin: Callable,
    ) -> numpy.ndarray:
        """Compute the margin-based scores for a batch of sentence pairs.

        For every position `(i, j)` of `candidate_inds` the pair
        `(source_embeddings[i], target_embeddings[candidate_inds[i, j]])` is scored with
        `margin_based_score`.

        Args:
            source_embeddings (numpy.ndarray): Source embeddings.
            target_embeddings (numpy.ndarray): Target embeddings.
            candidate_inds (numpy.ndarray): The indices of the candidate target embeddings for each source embedding.
                Must be two-dimensional: one row per source embedding, one column per candidate.
            fwd_mean (numpy.ndarray): The forward mean. Indexed by source row.
            bwd_mean (numpy.ndarray): The backward mean. Indexed by candidate (target) index.
            margin (Callable): The margin function.

        Returns:
            scores (numpy.ndarray): The margin-based scores for the candidate pairs, same shape as `candidate_inds`.
        """
        scores = numpy.zeros(candidate_inds.shape)
        # The margin callable is invoked once per pair on purpose: a caller-supplied callable
        # is only guaranteed to handle the per-pair values it has always received.
        for i in range(scores.shape[0]):
            for j in range(scores.shape[1]):
                k = candidate_inds[i, j]
                scores[i, j] = self.margin_based_score(
                    source_embeddings[i],
                    target_embeddings[k],
                    fwd_mean[i],
                    bwd_mean[k],
                    margin,
                )
        return scores

    def select_best_candidates(
        self,
        source_embeddings: numpy.ndarray,
        x2y_ind: numpy.ndarray,
        fwd_scores: numpy.ndarray,
        target_embeddings: numpy.ndarray,
        y2x_ind: numpy.ndarray,
        bwd_scores: numpy.ndarray,
        strategy: str = "max_score",
    ) -> tuple[numpy.ndarray, numpy.ndarray]:
        """Select the best sentence pairs.

        Source: https://arxiv.org/pdf/1811.01136.pdf 3.2 Candidate generation and filtering (only max. score)

        With `max_score`, the best-scoring candidate of every source row (forward direction) and of
        every target row (backward direction) is kept. Forward pairs come first in the output,
        followed by backward pairs.

        Args:
            source_embeddings (numpy.ndarray): Source embeddings. Only the number of rows is used.
            x2y_ind (numpy.ndarray): Indices of the target sentences corresponding to each source sentence.
            fwd_scores (numpy.ndarray): Scores of the forward alignment between source and target sentences.
            target_embeddings (numpy.ndarray): Target embeddings. Only the number of rows is used.
            y2x_ind (numpy.ndarray): Indices of the source sentences corresponding to each target sentence.
            bwd_scores (numpy.ndarray): Scores of the backward alignment between target and source sentences.
            strategy (str, optional): The strategy to use for selecting the best candidates.
                Only `max_score` is implemented.

        Raises:
            NotImplementedError: If the given `strategy` is not implemented.

        Returns:
            - indices (numpy.ndarray): An array of indices representing the sentence pairs. Each row is
              `(source index, target index)`.
            - scores (numpy.ndarray): An array of scores representing the similarity between the sentence pairs.
        """
        # Reject unknown strategies before doing any array work.
        if strategy != "max_score":
            msg = f"`{strategy}` is not implemented"
            raise NotImplementedError(msg)

        source_rows = numpy.arange(source_embeddings.shape[0])
        target_rows = numpy.arange(target_embeddings.shape[0])

        # Best target for each source sentence, and best source for each target sentence.
        fwd_best = x2y_ind[source_rows, fwd_scores.argmax(axis=1)]
        bwd_best = y2x_ind[target_rows, bwd_scores.argmax(axis=1)]

        # Column 0 holds source indices, column 1 holds target indices.
        indices = numpy.stack(
            [
                numpy.concatenate([source_rows, bwd_best]),
                numpy.concatenate([fwd_best, target_rows]),
            ],
            axis=1,
        )
        scores = numpy.concatenate([fwd_scores.max(axis=1), bwd_scores.max(axis=1)])

        return indices, scores

    def get_sentence_pairs(
        self,
        indices: numpy.ndarray,
        scores: numpy.ndarray,
        source_sentences: list[str],
        target_sentences: list[str],
    ) -> list[tuple[numpy.float64, str, str]]:
        """Get the sentence pairs.

        Candidates are visited from the highest to the lowest score and a pair is kept only if
        neither its source nor its target sentence has been used by a previously kept pair.
        Candidates with equal scores are visited in their original order.

        Args:
            indices (numpy.ndarray): An array of indices representing the sentence pairs.
            scores (numpy.ndarray): An array of scores representing the similarity between the sentence pairs.
            source_sentences (list[str]): Source sentences.
            target_sentences (list[str]): Target sentences.

        Returns:
            bitext_list (list[tuple[numpy.float64, str, str]]): A list of tuples with score (rounded to
                4 decimals), source sentences and target sentences, ordered by decreasing score.
        """
        seen_src, seen_trg = set(), set()

        bitext_list = []

        # A stable sort makes tie-breaking deterministic: among equal scores the candidate that
        # appears first in `indices` wins.
        for i in numpy.argsort(-scores, kind="stable"):
            src_ind, trg_ind = indices[i]
            # Plain ints so that set membership does not depend on the numpy integer dtype.
            src_ind = int(src_ind)
            trg_ind = int(trg_ind)

            if src_ind in seen_src or trg_ind in seen_trg:
                continue

            seen_src.add(src_ind)
            seen_trg.add(trg_ind)
            rounded_score = numpy.round(scores[i], 4)
            bitext_list.append((rounded_score, source_sentences[src_ind], target_sentences[trg_ind]))

        return bitext_list