Coverage for hadal/huggingface_automodel.py: 86%
74 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-09 22:55 +0300
1"""This module contains the `class HuggingfaceAutoModel` that can be used to encode text using a Huggingface AutoModel."""
2from __future__ import annotations
4import logging
5from typing import TYPE_CHECKING
7import numpy
8import torch
9from tqdm.autonotebook import trange
10from transformers import AutoConfig, AutoModel, AutoTokenizer
12from hadal.custom_logger import default_custom_logger
14if TYPE_CHECKING:
15 import pathlib
18class HuggingfaceAutoModel:
19 """Class to encode text using a Huggingface AutoModel.
21 Methods:
22 encode: Encode text using a Huggingface AutoModel.
23 """
25 def __init__(
26 self,
27 model_name_or_path: str | pathlib.Path,
28 device: str | None = None,
29 *,
30 enable_logging: bool = True,
31 log_level: int | None = logging.INFO,
32 ) -> None:
33 """Initialize HuggingfaceAutoModel object.
35 Args:
36 model_name_or_path (str | pathlib.Path): Name or path to the pre-trained model.
37 device (str | None, optional): Device for the model.
38 enable_logging (bool, optional): Logging option.
39 log_level (int | None, optional): Logging level.
40 """
41 if enable_logging is True:
42 self.logger = default_custom_logger(name=__name__, level=log_level)
43 else:
44 self.logger = logging.getLogger(__name__)
45 self.logger.disabled = True
47 if device is None:
48 device = "cuda" if torch.cuda.is_available() else "cpu"
49 self.logger.info("Pytorch device: %s", device)
51 self._target_device = torch.device(device)
53 self.model = AutoModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path)
54 self.model.eval()
55 self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path)
56 self.config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_name_or_path)
58 if model_name_or_path is not None and model_name_or_path != "":
59 self.logger.info("Load huggingface model: %s", model_name_or_path)
61 def encode(
62 self,
63 sentences: str | list[str],
64 batch_size: int = 32,
65 output_value: str = "pooler_output",
66 convert_to: str | None = None,
67 *,
68 normalize_embeddings: bool = False,
69 device: str | None = None,
70 ) -> list[torch.Tensor] | torch.Tensor | numpy.ndarray:
71 """Encode text using a Huggingface AutoModel.
73 Args:
74 sentences (str | list[str]): The sentences to encode.
75 batch_size (int, optional): The batch size.
76 output_value (str, optional): Model output type. Can be `pooler_output` or `last_hidden_state`.
77 convert_to (str | None, optional): Convert the embeddings to `torch` or `numpy` format. If `torch`, it will return a `torch.Tensor`. If `numpy`, it will return a `numpy.ndarray`. If `None`, it will return a `list[torch.Tensor]`.
78 normalize_embeddings (bool, optional): Normalize the embeddings.
79 device (str | None, optional): Device for the model.
81 Raises:
82 NotImplementedError: If the `output_value` is not implemented.
84 Returns:
85 all_embeddings (list[torch.Tensor] | torch.Tensor | numpy.ndarray): The embeddings of the sentences.
86 """
87 if device is None:
88 device = self._target_device
90 self.model.to(device)
91 self.logger.info("Encoding on pytorch device: %s", device)
93 if isinstance(sentences, str):
94 sentences = [sentences]
96 all_embeddings = []
97 length_sorted_idx = numpy.argsort([-self._text_length(sen) for sen in sentences])
98 sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
100 for start_index in trange(0, len(sentences), batch_size, desc="Batches"):
101 sentences_batch = sentences_sorted[start_index : start_index + batch_size]
102 inputs = self.tokenizer(sentences_batch, return_tensors="pt", truncation=True, padding=True)
103 inputs = batch_to_device(batch=inputs, target_device=device)
105 with torch.no_grad():
106 outputs = self.model(**inputs)
108 if output_value == "pooler_output":
109 embeddings = outputs.pooler_output
110 elif output_value == "last_hidden_state":
111 embeddings = outputs.last_hidden_state[:, 0, :]
112 else:
113 msg = f"output_value=`{output_value}` not implemented"
114 raise NotImplementedError(msg)
116 if normalize_embeddings is True:
117 # apply L2 normalization to the embeddings
118 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
120 all_embeddings.extend(embeddings)
122 all_embeddings: list[torch.Tensor] = [all_embeddings[idx] for idx in numpy.argsort(length_sorted_idx)]
124 if convert_to == "torch":
125 all_embeddings: torch.Tensor = torch.stack(all_embeddings)
126 elif convert_to == "numpy":
127 all_embeddings: numpy.ndarray = torch.stack(all_embeddings).numpy()
129 return all_embeddings
131 def _text_length(self, text: list[str] | list | str) -> int:
132 """Calculate the length of the given sentences.
134 Args:
135 text (list[str] | list | str): The sentences.
137 Raises:
138 TypeError: Input cannot be a `dict`.
139 TypeError: Input cannot be a `tuple`.
141 Returns:
142 length (int): The length of the text.
143 """
144 if isinstance(text, dict):
145 msg = "Input cannot be a `dict`."
146 raise TypeError(msg)
147 if isinstance(text, tuple):
148 msg = "Input cannot be a `tuple`."
149 raise TypeError(msg)
151 if not hasattr(text, "__len__"): # no len() method
152 return 1
153 if len(text) == 0: # empty string or list
154 return len(text)
155 return sum([len(t) for t in text]) # sum of length of individual strings
def batch_to_device(batch, target_device: torch.device):  # noqa: ANN201, ANN001
    """Move every tensor value of a mapping-like batch onto `target_device`.

    The batch is updated in place: each tensor entry is replaced by the result of
    `.to(target_device)`, and non-tensor entries are left untouched. The same object
    is returned so the call can be chained or reassigned.

    Args:
        batch: Mutable mapping from names to values, e.g. a tokenizer output.
        target_device (torch.device): Device the tensors should be moved to.

    Returns:
        The same `batch` object, with its tensor entries moved.
    """
    tensor_keys = [name for name in batch if isinstance(batch[name], torch.Tensor)]
    for name in tensor_keys:
        batch[name] = batch[name].to(target_device)
    return batch