Coverage for hadal/parallel_sentence_mining/margin_based/margin_based.py: 91%

43 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-09 22:55 +0300

1"""This module contains the `class MarginBasedPipeline` that implements the margin-based pipeline.""" 

2from __future__ import annotations 

3 

4import logging 

5from typing import TYPE_CHECKING 

6 

7from hadal.custom_logger import default_custom_logger 

8from hadal.faiss_search import FaissSearch 

9from hadal.huggingface_automodel import HuggingfaceAutoModel 

10from hadal.parallel_sentence_mining.margin_based.margin_based_tools import MarginBased 

11 

12if TYPE_CHECKING: 

13 import pathlib 

14 

15 import numpy 

16 

17 

class MarginBasedPipeline:
    """Class that implements the margin-based pipeline.

    The pipeline embeds source and target sentences with a Hugging Face model,
    runs a k-nearest-neighbor search in both directions with `FaissSearch`,
    scores the candidates with a margin function from `MarginBased`, and keeps
    the best sentence pairs.

    Attributes:
        model: Sentence encoder (`HuggingfaceAutoModel`).
        align_method: `MarginBased` helper that scores and selects candidates.
        faiss_search: `FaissSearch` helper used for the kNN search.
        model_device: Device requested for the model, as passed to `__init__`.
        faiss_device: Device requested for Faiss, as passed to `__init__`.
        logger: Logger used for progress messages.

    Methods:
        make_alignments: Make sentence alignments.
    """

    def __init__(
        self,
        model_name_or_path: str | pathlib.Path,
        model_device: str | None = None,
        faiss_device: str | None = None,
        *,
        enable_logging: bool = True,
        log_level: int | None = logging.INFO,
    ) -> None:
        """Initialize a MarginBasedPipeline object.

        Args:
            model_name_or_path (str | pathlib.Path): Name or path to the pre-trained model.
            model_device (str | None, optional): Device for the model.
            faiss_device (str | None, optional): Device for the Faiss search. If `None`, it will use GPU if available, otherwise CPU.
            enable_logging (bool, optional): Logging option.
            log_level (int | None, optional): Logging level.
        """
        self.model = HuggingfaceAutoModel(
            model_name_or_path=model_name_or_path,
            device=model_device,
            enable_logging=enable_logging,
        )
        self.align_method = MarginBased()
        self.faiss_search = FaissSearch(device=faiss_device, enable_logging=enable_logging)
        self.model_device = model_device
        self.faiss_device = faiss_device

        if enable_logging:
            self.logger = default_custom_logger(name=__name__, level=log_level)
        else:
            # A logger object is still required because every step calls
            # `self.logger.info(...)`. `logging.getLogger` returns one shared
            # object per name, so this disables the module logger process-wide.
            # NOTE(review): this may also silence other pipeline instances that
            # log under the same name — confirm this is intended.
            self.logger = logging.getLogger(__name__)
            self.logger.disabled = True

    def make_alignments(
        self,
        source_sentences: list[str],
        target_sentences: list[str],
        batch_size: int = 32,
        output_value: str = "pooler_output",
        convert_to: str | None = "numpy",
        *,
        normalize_embeddings: bool = True,
        knn_neighbors: int = 4,
        knn_metric: str = "inner_product",
        margin: str = "ratio",
        strategy: str = "max_score",
    ) -> list[tuple[numpy.float64, str, str]]:
        """Make sentence alignments.

        Args:
            source_sentences (list[str]): Source sentences.
            target_sentences (list[str]): Target sentences.
            batch_size (int, optional): The batch size.
            output_value (str, optional): Model output type. Can be `pooler_output` or `last_hidden_state`.
            convert_to (str | None, optional): Convert the embeddings to `torch` or `numpy` format. If `torch`, it will return a `torch.Tensor`. If `numpy`, it will return a `numpy.ndarray`. If `None`, it will return a `list[torch.Tensor]`.
            normalize_embeddings (bool, optional): Normalize the embeddings.
            knn_neighbors (int, optional): The number of nearest neighbors. It is capped
                at the size of the searched corpus in each direction.
            knn_metric (str, optional): The metric to use for k-nearest neighbor search. Can be `inner_product` or `l2`.
            margin (str, optional): The margin function to use. Valid options are `ratio` and `distance`.
            strategy (str, optional): The strategy to use for selecting the best candidates.

        Returns:
            bitext_list (list[tuple[numpy.float64, str, str]]): The `list[tuple[score, source_sentence, target_sentence]]` of the best sentence alignments.
                Empty when either input list is empty.

        Raises:
            ValueError: If `knn_neighbors` is smaller than 1.
        """
        if knn_neighbors < 1:
            msg = f"knn_neighbors must be a positive integer, got {knn_neighbors}"
            raise ValueError(msg)

        # Nothing can be aligned against an empty side. Return early instead of
        # letting the encoder or kNN search fail on empty input.
        if len(source_sentences) == 0 or len(target_sentences) == 0:
            self.logger.warning("Source or target sentences are empty; no alignments to make.")
            return []

        encode_options = {
            "batch_size": batch_size,
            "output_value": output_value,
            "convert_to": convert_to,
            "normalize_embeddings": normalize_embeddings,
        }
        self.logger.info("Encoding embeddings for source sentences...")
        source_embeddings = self._encode(source_sentences, **encode_options)
        self.logger.info("Encoding embeddings for target sentences...")
        target_embeddings = self._encode(target_sentences, **encode_options)

        self.logger.info("Perform kNN in both directions...")
        self.logger.info("Perform kNN in source -> target direction")
        x2y_ind, x2y_mean = self._knn_with_mean(
            source_embeddings,
            target_embeddings,
            k=knn_neighbors,
            corpus_size=len(target_sentences),
            knn_metric=knn_metric,
        )

        self.logger.info("Perform kNN in target -> source direction")
        y2x_ind, y2x_mean = self._knn_with_mean(
            target_embeddings,
            source_embeddings,
            k=knn_neighbors,
            corpus_size=len(source_sentences),
            knn_metric=knn_metric,
        )

        self.logger.info("%s margin is selected", margin)
        chosen_margin = self.align_method.select_margin(margin=margin)

        self.logger.info("Compute forward and backward scores...")
        # The backward direction swaps the roles of the two sides, including
        # which neighborhood mean is passed first.
        fwd_scores = self.align_method.margin_based_score_candidates(
            source_embeddings,
            target_embeddings,
            x2y_ind,
            x2y_mean,
            y2x_mean,
            margin=chosen_margin,
        )
        bwd_scores = self.align_method.margin_based_score_candidates(
            target_embeddings,
            source_embeddings,
            y2x_ind,
            y2x_mean,
            x2y_mean,
            margin=chosen_margin,
        )

        self.logger.info("Selecting best candidates...")
        indices, scores = self.align_method.select_best_candidates(
            source_embeddings,
            x2y_ind,
            fwd_scores,
            target_embeddings,
            y2x_ind,
            bwd_scores,
            strategy=strategy,
        )

        bitext_list = self.align_method.get_sentence_pairs(indices, scores, source_sentences, target_sentences)

        self.logger.info("Output sentences: pairs %d", len(bitext_list))

        return bitext_list

    def _encode(
        self,
        sentences: list[str],
        *,
        batch_size: int,
        output_value: str,
        convert_to: str | None,
        normalize_embeddings: bool,
    ):
        """Embed `sentences` with the pipeline's model using the given encoding options."""
        return self.model.encode(
            sentences=sentences,
            batch_size=batch_size,
            output_value=output_value,
            convert_to=convert_to,
            normalize_embeddings=normalize_embeddings,
        )

    def _knn_with_mean(
        self,
        queries,
        corpus,
        *,
        k: int,
        corpus_size: int,
        knn_metric: str,
    ):
        """Find the nearest `corpus` rows for each query.

        Returns:
            A `(indices, mean_similarity)` pair. `mean_similarity` is the
            per-query mean over the returned neighbors.
        """
        # FAISS pads results with label -1 and sentinel distances when k exceeds
        # the number of indexed vectors. A -1 would silently index the last
        # sentence and the sentinel would corrupt the neighborhood mean, so k is
        # capped at the corpus size (as the reference LASER miner does).
        effective_k = min(k, corpus_size)
        if effective_k < k:
            self.logger.warning(
                "knn_neighbors=%d exceeds the corpus size (%d); using k=%d",
                k,
                corpus_size,
                effective_k,
            )
        similarities, indices = self.faiss_search.k_nearest_neighbors(
            queries,
            corpus,
            k=effective_k,
            knn_metric=knn_metric,
        )
        return indices, similarities.mean(axis=1)