Coverage for hadal/huggingface_automodel.py: 86%

74 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-09 22:55 +0300

1"""This module contains the `class HuggingfaceAutoModel` that can be used to encode text using a Huggingface AutoModel.""" 

2from __future__ import annotations 

3 

4import logging 

5from typing import TYPE_CHECKING 

6 

7import numpy 

8import torch 

9from tqdm.autonotebook import trange 

10from transformers import AutoConfig, AutoModel, AutoTokenizer 

11 

12from hadal.custom_logger import default_custom_logger 

13 

14if TYPE_CHECKING: 

15 import pathlib 

16 

17 

class HuggingfaceAutoModel:
    """Class to encode text using a Huggingface AutoModel.

    Methods:
        encode: Encode text using a Huggingface AutoModel.
    """

    def __init__(
        self,
        model_name_or_path: str | pathlib.Path,
        device: str | None = None,
        *,
        enable_logging: bool = True,
        log_level: int | None = logging.INFO,
    ) -> None:
        """Initialize HuggingfaceAutoModel object.

        Loads the model, tokenizer and config for `model_name_or_path` and puts the model into eval mode.

        Args:
            model_name_or_path (str | pathlib.Path): Name or path to the pre-trained model.
            device (str | None, optional): Device for the model. If `None`, `cuda` is used when available, otherwise `cpu`.
            enable_logging (bool, optional): Logging option. If not `True`, the module logger is disabled.
            log_level (int | None, optional): Logging level.
        """
        if enable_logging is True:
            self.logger = default_custom_logger(name=__name__, level=log_level)
        else:
            self.logger = logging.getLogger(__name__)
            self.logger.disabled = True

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.logger.info("Pytorch device: %s", device)

        self._target_device = torch.device(device)

        self.model = AutoModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path)
        self.config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_name_or_path)

        if model_name_or_path is not None and model_name_or_path != "":
            self.logger.info("Load huggingface model: %s", model_name_or_path)

    def encode(
        self,
        sentences: str | list[str],
        batch_size: int = 32,
        output_value: str = "pooler_output",
        convert_to: str | None = None,
        *,
        normalize_embeddings: bool = False,
        device: str | None = None,
    ) -> list[torch.Tensor] | torch.Tensor | numpy.ndarray:
        """Encode text using a Huggingface AutoModel.

        Sentences are processed longest-first so each batch needs little padding; the
        resulting embeddings are returned in the original input order.

        Args:
            sentences (str | list[str]): The sentences to encode.
            batch_size (int, optional): The batch size.
            output_value (str, optional): Model output type. Can be `pooler_output` or `last_hidden_state`.
            convert_to (str | None, optional): Convert the embeddings to `torch` or `numpy` format. If `torch`, it will return a `torch.Tensor`. If `numpy`, it will return a `numpy.ndarray`. If `None`, it will return a `list[torch.Tensor]`.
            normalize_embeddings (bool, optional): Normalize the embeddings.
            device (str | None, optional): Device for the model.

        Raises:
            NotImplementedError: If the `output_value` is not implemented.

        Returns:
            all_embeddings (list[torch.Tensor] | torch.Tensor | numpy.ndarray): The embeddings of the sentences.
        """
        # Validate up front: fail before moving the model or running a forward pass,
        # and also when `sentences` is empty.
        if output_value not in ("pooler_output", "last_hidden_state"):
            msg = f"output_value=`{output_value}` not implemented"
            raise NotImplementedError(msg)

        if device is None:
            device = self._target_device

        self.model.to(device)
        self.logger.info("Encoding on pytorch device: %s", device)

        if isinstance(sentences, str):
            sentences = [sentences]

        all_embeddings: list[torch.Tensor] = []
        # Negated lengths make argsort yield the longest sentences first.
        length_sorted_idx = numpy.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in trange(0, len(sentences), batch_size, desc="Batches"):
            sentences_batch = sentences_sorted[start_index : start_index + batch_size]
            inputs = self.tokenizer(sentences_batch, return_tensors="pt", truncation=True, padding=True)
            inputs = batch_to_device(batch=inputs, target_device=device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            if output_value == "pooler_output":
                embeddings = outputs.pooler_output
            else:
                # `last_hidden_state`: keep only the vector at the first token position.
                embeddings = outputs.last_hidden_state[:, 0, :]

            if normalize_embeddings is True:
                # apply L2 normalization to the embeddings
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

            # Iterating a 2-D tensor yields its rows, i.e. one embedding per sentence.
            all_embeddings.extend(embeddings)

        # argsort of the sorting permutation is its inverse: restores the input order.
        restored = [all_embeddings[idx] for idx in numpy.argsort(length_sorted_idx)]

        return self._convert_embeddings(restored, convert_to)

    @staticmethod
    def _convert_embeddings(
        embeddings: list[torch.Tensor],
        convert_to: str | None,
    ) -> list[torch.Tensor] | torch.Tensor | numpy.ndarray:
        """Convert a list of per-sentence embeddings to the requested container.

        Args:
            embeddings (list[torch.Tensor]): One tensor per sentence.
            convert_to (str | None): `torch`, `numpy`, or anything else to keep the list.

        Returns:
            converted (list[torch.Tensor] | torch.Tensor | numpy.ndarray): The stacked tensor for `torch`, an array for `numpy`, otherwise `embeddings` unchanged.
        """
        if convert_to not in ("torch", "numpy"):
            return embeddings

        # `torch.stack` rejects an empty list; without a forward pass the embedding
        # width is unknown, so an empty result has shape (0,).
        stacked = torch.stack(embeddings) if embeddings else torch.empty(0)

        if convert_to == "numpy":
            # `Tensor.numpy()` only works on CPU tensors, so copy off the accelerator first.
            return stacked.cpu().numpy()
        return stacked

    def _text_length(self, text: list[str] | list | str) -> int:
        """Calculate the length of the given sentences.

        Args:
            text (list[str] | list | str): The sentences.

        Raises:
            TypeError: Input cannot be a `dict`.
            TypeError: Input cannot be a `tuple`.

        Returns:
            length (int): The length of the text.
        """
        if isinstance(text, dict):
            msg = "Input cannot be a `dict`."
            raise TypeError(msg)
        if isinstance(text, tuple):
            msg = "Input cannot be a `tuple`."
            raise TypeError(msg)

        if not hasattr(text, "__len__"):  # no len() method
            return 1
        if len(text) == 0:  # empty string or list
            return 0
        # For a string this counts characters; for a list it sums the item lengths.
        return sum(len(t) for t in text)

156 

157 

def batch_to_device(batch, target_device: torch.device):  # noqa: ANN201, ANN001
    """Move a batch of tensors to the specified device.

    Args:
        batch: Mapping-like object that supports iteration over its keys and item get/set.
        target_device (torch.device): Device the tensor values are moved to.

    Returns:
        batch: The same object that was passed in; its `torch.Tensor` values are replaced in place, all other values are left untouched.
    """
    for key in batch:
        value = batch[key]
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(target_device)
    return batch