Spaces:
Paused
Paused
import copy
import os
import tempfile
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import torch
from huggingface_hub.utils import SoftTemporaryDirectory

from setfit.utils import set_docstring

from .. import logging
from ..modeling import SetFitModel
from .aspect_extractor import AspectExtractor


if TYPE_CHECKING:
    from spacy.tokens import Doc
# Module-level logger, obtained through the package's own `logging` wrapper.
logger = logging.get_logger(__name__)
@dataclass
class SpanSetFitModel(SetFitModel):
    """SetFit model that classifies spans (aspect candidates) inside spaCy documents.

    Every candidate span is converted into one input string of the form
    ``"<span text>:<document text>"`` and classified with the inherited ``predict``.
    """

    # Name of the spaCy model associated with this model; recorded in the model card.
    spacy_model: str = "en_core_web_lg"
    # Number of extra tokens taken on each side of a span when building the prepended span text.
    span_context: int = 0

    # NOTE(review): presumably the attribute names persisted with the model — confirm against SetFitModel.
    attributes_to_save: Set[str] = field(
        init=False,
        repr=False,
        default_factory=lambda: {"normalize_embeddings", "labels", "span_context", "spacy_model"},
    )

    def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> Iterator[str]:
        """Yield one ``"<span text>:<document text>"`` string per aspect candidate.

        Args:
            docs: The spaCy documents.
            aspects_list: For each document, the token slices of its aspect candidates.

        Yields:
            The input strings, in document order and then candidate order.
        """
        for doc, aspects in zip(docs, aspects_list):
            for aspect_slice in aspects:
                # Widen the span by `span_context` tokens on both sides; the start is clamped at 0
                # so that it can never become negative and index from the end of the document.
                aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context]
                # TODO: Investigate performance difference of different formats
                yield aspect.text + ":" + doc.text

    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[List[Any]]:
        """Classify every aspect candidate and regroup the predictions per document.

        Args:
            docs: The spaCy documents.
            aspects_list: For each document, the token slices of its aspect candidates.

        Returns:
            A nested list with the same shape as ``aspects_list``, holding one prediction per candidate.
        """
        inputs_list = list(self.prepend_aspects(docs, aspects_list))
        preds = self.predict(inputs_list, as_numpy=True)
        # The predictions are flat; consume them in order to rebuild the per-document nesting.
        iter_preds = iter(preds)
        return [[next(iter_preds) for _ in aspects] for aspects in aspects_list]

    def create_model_card(self, path: str, model_name: Optional[str] = None) -> None:
        """Creates and saves a model card for a SetFit model.

        The card is written to ``README.md`` inside ``path``; the directory is created if needed.

        Args:
            path (str): The path to save the model card to.
            model_name (str, *optional*): The name of the model. When omitted, the placeholder names
                `setfit-absa-aspect` and `setfit-absa-polarity` are used for the two ABSA models.
        """
        os.makedirs(path, exist_ok=True)

        # If the model_path is a folder that exists locally, i.e. when create_model_card is called
        # via push_to_hub, and the path is in a temporary folder, then we only take the last two
        # directories
        # `Path(None)` raises a TypeError, so the path inspection is skipped when no name was given.
        if model_name is not None:
            model_path = Path(model_name)
            if model_path.exists() and Path(tempfile.gettempdir()) in model_path.resolve().parents:
                model_name = "/".join(model_path.parts[-2:])

        is_aspect = isinstance(self, AspectModel)
        aspect_model = "setfit-absa-aspect"
        polarity_model = "setfit-absa-polarity"
        if model_name is not None:
            # Derive the sibling model's name by swapping the "-aspect" / "-polarity" suffix.
            if is_aspect:
                aspect_model = model_name
                if model_name.endswith("-aspect"):
                    polarity_model = model_name[: -len("-aspect")] + "-polarity"
            else:
                polarity_model = model_name
                if model_name.endswith("-polarity"):
                    aspect_model = model_name[: -len("-polarity")] + "-aspect"

        # Only once:
        if self.model_card_data.absa is None and self.model_card_data.model_name:
            from spacy import __version__ as spacy_version

            self.model_card_data.model_name = self.model_card_data.model_name.replace(
                "SetFit", "SetFit Aspect Model" if is_aspect else "SetFit Polarity Model", 1
            )
            self.model_card_data.tags.insert(1, "absa")
            self.model_card_data.version["spacy"] = spacy_version
        self.model_card_data.absa = {
            "is_absa": True,
            "is_aspect": is_aspect,
            "spacy_model": self.spacy_model,
            "aspect_model": aspect_model,
            "polarity_model": polarity_model,
        }
        if self.model_card_data.task_name is None:
            self.model_card_data.task_name = "Aspect Based Sentiment Analysis (ABSA)"
        self.model_card_data.inference = False
        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
            f.write(self.generate_model_card())
# Rewrite the inherited `from_pretrained` docstring: everything from the "multi_target_strategy"
# entry onwards is replaced by the parameter descriptions below.
docstring = SpanSetFitModel.from_pretrained.__doc__
# `__doc__` is None when docstrings are stripped (e.g. `python -OO`); treat that like "marker not found"
# instead of failing at import time with an AttributeError.
cut_index = -1 if docstring is None else docstring.find("multi_target_strategy")
if cut_index != -1:
    # NOTE(review): the indentation inside the literal assumes the inherited docstring indents parameter
    # names by 12 spaces and their descriptions by 16 — confirm against SetFitModel.from_pretrained.
    docstring = (
        docstring[:cut_index]
        + """model_card_data (`SetFitModelCardData`, *optional*):
                A `SetFitModelCardData` instance storing data such as model language, license, dataset name,
                etc. to be used in the automatically generated model cards.
            use_differentiable_head (`bool`, *optional*):
                Whether to load SetFit using a differentiable (i.e., Torch) head instead of Logistic Regression.
            normalize_embeddings (`bool`, *optional*):
                Whether to apply normalization on the embeddings produced by the Sentence Transformer body.
            span_context (`int`, defaults to `0`):
                The number of words before and after the span candidate that should be prepended to the full sentence.
                By default, 0 for Aspect models and 3 for Polarity models.
            device (`Union[torch.device, str]`, *optional*):
                The device on which to load the SetFit model, e.g. `"cuda:0"`, `"mps"` or `torch.device("cuda")`."""
    )
SpanSetFitModel.from_pretrained = set_docstring(SpanSetFitModel.from_pretrained, docstring, cls=SpanSetFitModel)
class AspectModel(SpanSetFitModel):
    """Span model that filters aspect candidates, keeping those predicted as ``"aspect"``."""

    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[List[slice]]:
        """Return, per document, only the candidate slices classified as ``"aspect"``.

        Args:
            docs: The spaCy documents.
            aspects_list: For each document, the token slices of its aspect candidates.

        Returns:
            A list with one entry per document containing the retained slices, in their original order.
        """
        sentence_preds = super().__call__(docs, aspects_list)
        return [
            [aspect for aspect, pred in zip(aspects, preds) if pred == "aspect"]
            for aspects, preds in zip(aspects_list, sentence_preds)
        ]
# The set_docstring magic has as a consequence that subclasses need to update the cls in the from_pretrained
# classmethod, otherwise the wrong class will be instantiated.
AspectModel.from_pretrained = types.MethodType(AspectModel.from_pretrained.__func__, AspectModel)
@dataclass
class PolarityModel(SpanSetFitModel):
    """Span model used for the polarity step; differs from its parent only in the `span_context` default."""

    # Re-declared under `@dataclass` so that the generated `__init__` uses 3 as the default
    # instead of the parent's 0.
    span_context: int = 3
# Re-bind `from_pretrained` to PolarityModel so that the underlying function receives PolarityModel as `cls`.
PolarityModel.from_pretrained = types.MethodType(PolarityModel.from_pretrained.__func__, PolarityModel)
@dataclass
class AbsaModel:
    """Aspect Based Sentiment Analysis pipeline.

    Chains three components: an `AspectExtractor` proposing aspect candidates, an `AspectModel`
    filtering those candidates, and a `PolarityModel` classifying each remaining aspect.
    """

    aspect_extractor: AspectExtractor
    aspect_model: AspectModel
    polarity_model: PolarityModel

    def predict(
        self, inputs: Union[str, List[str]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Extract aspects from the input text(s) and predict a polarity for each of them.

        Args:
            inputs: A single text or a list of texts.

        Returns:
            For a list input: one list per text, holding ``{"span": ..., "polarity": ...}`` dictionaries.
            For a single string input: just the list of dictionaries for that text. Texts without any
            aspect yield an empty list.
        """
        is_str = isinstance(inputs, str)
        inputs_list = [inputs] if is_str else inputs
        docs, aspects_list = self.aspect_extractor(inputs_list)
        if not any(aspects_list):
            # No candidates at all: skip both models, but unwrap for string input like the normal path does.
            return aspects_list[0] if is_str else aspects_list

        aspects_list = self.aspect_model(docs, aspects_list)
        if not any(aspects_list):
            # Every candidate was rejected by the aspect model.
            return aspects_list[0] if is_str else aspects_list

        polarity_list = self.polarity_model(docs, aspects_list)
        outputs = []
        for doc, aspects, polarities in zip(docs, aspects_list, polarity_list):
            outputs.append(
                [
                    {"span": doc[aspect_slice].text, "polarity": polarity}
                    for aspect_slice, polarity in zip(aspects, polarities)
                ]
            )
        return outputs if not is_str else outputs[0]

    @property
    def device(self) -> torch.device:
        """The device of the aspect model."""
        return self.aspect_model.device

    def to(self, device: Union[str, torch.device]) -> "AbsaModel":
        """Move the aspect and polarity models to ``device``.

        Args:
            device: The target device, as a string or `torch.device`.

        Returns:
            This `AbsaModel` instance, so that calls can be chained.
        """
        self.aspect_model.to(device)
        self.polarity_model.to(device)
        return self

    def __call__(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
        """Alias for `predict`."""
        return self.predict(inputs)

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        polarity_save_directory: Optional[Union[str, Path]] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """Save the aspect and polarity models by calling `save_pretrained` on each of them.

        Args:
            save_directory: Directory for the aspect model. When ``polarity_save_directory`` is omitted,
                this is treated as a base path and the models are saved to ``<base>-aspect`` and
                ``<base>-polarity`` instead.
            polarity_save_directory: Directory for the polarity model.
            push_to_hub: Forwarded to both models' `save_pretrained`.
            **kwargs: Forwarded to both models' `save_pretrained`.
        """
        if polarity_save_directory is None:
            base_save_directory = Path(save_directory)
            save_directory = base_save_directory.parent / (base_save_directory.name + "-aspect")
            polarity_save_directory = base_save_directory.parent / (base_save_directory.name + "-polarity")
        self.aspect_model.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
        self.polarity_model.save_pretrained(save_directory=polarity_save_directory, push_to_hub=push_to_hub, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        polarity_model_id: Optional[str] = None,
        spacy_model: Optional[str] = None,
        span_contexts: Tuple[Optional[int], Optional[int]] = (None, None),
        force_download: Optional[bool] = None,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[str] = None,
        local_files_only: Optional[bool] = None,
        use_differentiable_head: Optional[bool] = None,
        normalize_embeddings: Optional[bool] = None,
        **model_kwargs,
    ) -> "AbsaModel":
        """Load the aspect and polarity models and assemble them into an `AbsaModel`.

        Args:
            model_id: Identifier of the aspect model, optionally suffixed with ``@<revision>``. Also used
                for the polarity model when ``polarity_model_id`` is not given.
            polarity_model_id: Identifier of the polarity model, optionally suffixed with ``@<revision>``.
            spacy_model: When given, forwarded to both models as ``spacy_model``.
            span_contexts: The ``span_context`` values passed to the aspect and polarity model, respectively.
            force_download, resume_download, proxies, token, cache_dir, local_files_only,
            use_differentiable_head, normalize_embeddings: Forwarded unchanged to both `from_pretrained` calls.
            **model_kwargs: Extra keyword arguments forwarded to both `from_pretrained` calls. A
                ``model_card_data`` entry is deep-copied before it is handed to the polarity model.

        Returns:
            The assembled `AbsaModel`; its `AspectExtractor` uses the aspect model's spaCy model.
        """
        revision = None
        # Only an identifier with exactly one "@" is interpreted as "<model_id>@<revision>".
        if len(model_id.split("@")) == 2:
            model_id, revision = model_id.split("@")
        if spacy_model:
            model_kwargs["spacy_model"] = spacy_model
        aspect_model = AspectModel.from_pretrained(
            model_id,
            span_context=span_contexts[0],
            revision=revision,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            use_differentiable_head=use_differentiable_head,
            normalize_embeddings=normalize_embeddings,
            labels=["no aspect", "aspect"],
            **model_kwargs,
        )
        if polarity_model_id:
            model_id = polarity_model_id
            revision = None
            if len(model_id.split("@")) == 2:
                model_id, revision = model_id.split("@")
        # If model_card_data was provided, "separate" the instance between the Aspect
        # and Polarity models.
        model_card_data = model_kwargs.pop("model_card_data", None)
        if model_card_data:
            model_kwargs["model_card_data"] = copy.deepcopy(model_card_data)
        polarity_model = PolarityModel.from_pretrained(
            model_id,
            span_context=span_contexts[1],
            revision=revision,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            use_differentiable_head=use_differentiable_head,
            normalize_embeddings=normalize_embeddings,
            **model_kwargs,
        )
        if aspect_model.spacy_model != polarity_model.spacy_model:
            logger.warning(
                "The Aspect and Polarity models are configured to use different spaCy models:\n"
                f"* {repr(aspect_model.spacy_model)} for the aspect model, and\n"
                f"* {repr(polarity_model.spacy_model)} for the polarity model.\n"
                f"This model will use {repr(aspect_model.spacy_model)}."
            )

        aspect_extractor = AspectExtractor(spacy_model=aspect_model.spacy_model)

        return cls(aspect_extractor, aspect_model, polarity_model)

    def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
        """Save both models into a temporary directory and upload them via `save_pretrained(push_to_hub=True)`.

        Args:
            repo_id: Full repository ID, including the organisation. Passed on as the save directory of
                the aspect model (or as the shared base path when ``polarity_repo_id`` is omitted).
            polarity_repo_id: Full repository ID for the polarity model.
            **kwargs: Forwarded to `save_pretrained`. ``commit_message`` defaults to
                ``"Add SetFit ABSA model"``.

        Raises:
            ValueError: If ``repo_id`` or ``polarity_repo_id`` does not contain a ``/``.
        """
        if "/" not in repo_id:
            raise ValueError(
                '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
            )
        if polarity_repo_id is not None and "/" not in polarity_repo_id:
            raise ValueError(
                '`polarity_repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
            )
        commit_message = kwargs.pop("commit_message", "Add SetFit ABSA model")

        # Push the files to the repo in a single commit
        with SoftTemporaryDirectory() as tmp_dir:
            save_directory = Path(tmp_dir) / repo_id
            polarity_save_directory = None if polarity_repo_id is None else Path(tmp_dir) / polarity_repo_id
            self.save_pretrained(
                save_directory=save_directory,
                polarity_save_directory=polarity_save_directory,
                push_to_hub=True,
                commit_message=commit_message,
                **kwargs,
            )