import logging
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
from annotation_utils import labeled_span_to_id
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

logger = logging.getLogger(__name__)


def _embed_text_annotations(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    text_layer_name: str,
) -> Dict[LabeledSpan, torch.Tensor]:
    """Compute one embedding per span annotation of the given layer.

    The text is tokenized into windows of at most 512 tokens with a stride of 64, and the
    windows are passed through ``model``. Each span is embedded by max-pooling the last hidden
    states of its tokens. Predictions of the layer are embedded as well; they are moved into
    the main layer of a copy of the document before tokenization.

    Parameters:
        document: The document whose span annotations should be embedded. It is not modified.
        model: The model that produces ``last_hidden_state``.
        tokenizer: The tokenizer belonging to ``model``. It presumably needs to be a fast
            tokenizer, because ``encodings`` of its output is used (TODO confirm for slow ones).
        text_layer_name: Name of the span layer to embed, e.g. "labeled_spans".

    Returns:
        A mapping from span annotation (of the copied document) to its embedding tensor. Spans
        that map to an empty token range are skipped. If a span appears in several windows,
        the embedding from the last window wins and a warning is logged.

    Raises:
        ValueError: If the two tokenizations below disagree on the number of windows.
    """
    # work on a copy so that the original document is not modified
    document = document.copy()
    # tokenize_document does not yet consider predictions, so move them into the main layer
    # (clear() hands back the removed predictions, which are re-added as regular annotations)
    document[text_layer_name].extend(document[text_layer_name].predictions.clear())

    # filled by tokenize_document: per tokenized document, a dict from layer name to a
    # {text annotation: token annotation} mapping (as used below)
    added_annotations = []
    tokenizer_kwargs = {
        "max_length": 512,
        "stride": 64,
        "truncation": True,
        "return_overflowing_tokens": True,
    }
    tokenized_documents = tokenize_document(
        document,
        tokenizer=tokenizer,
        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        partition_layer="labeled_partitions",
        added_annotations=added_annotations,
        strict_span_conversion=False,
        **tokenizer_kwargs,
    )

    # just tokenize again to get tensors in the correct format for the model
    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
    # this is added when using return_overflowing_tokens=True, but the model does not accept it
    model_inputs.pop("overflow_to_sample_mapping", None)

    # the batch index of the model output is used to address the tokenized documents,
    # so both tokenizations must yield the same number of windows
    num_windows = len(model_inputs.encodings)
    if num_windows != len(tokenized_documents):
        raise ValueError(
            f"number of model input windows ({num_windows}) does not match the number of "
            f"tokenized documents ({len(tokenized_documents)})"
        )

    # inference only: no need to build the autograd graph
    with torch.no_grad():
        model_output = model(**model_inputs)
    hidden_states = model_output.last_hidden_state

    # get embeddings for all text annotations
    embeddings: Dict[LabeledSpan, torch.Tensor] = {}
    for batch_idx, (tokenized_document, annotation_mappings) in enumerate(
        zip(tokenized_documents, added_annotations)
    ):
        tok2text_ann = {
            tok_ann: text_ann for text_ann, tok_ann in annotation_mappings[text_layer_name].items()
        }
        for tok_ann in tokenized_document[text_layer_name]:
            # skip "empty" annotations
            if tok_ann.start == tok_ann.end:
                continue
            # use the max pooling strategy to get a single embedding for the annotation text
            embedding = hidden_states[batch_idx, tok_ann.start : tok_ann.end].max(dim=0).values
            text_ann = tok2text_ann[tok_ann]
            if text_ann in embeddings:
                logger.warning(
                    "Overwriting embedding for annotation '%s' (do you use striding?)", text_ann
                )
            embeddings[text_ann] = embedding

    return embeddings


def _annotate(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    pipeline: Pipeline,
    embedding_model: Optional[PreTrainedModel] = None,
    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
) -> None:
    """Run the prediction pipeline on a document and optionally attach span embeddings.

    The return value of ``pipeline(document)`` is not used, so the pipeline is expected to add
    its predictions to the document in place. If both an embedding model and tokenizer are
    given, ``document.metadata["embeddings"]`` is set to a dict that maps each span id (from
    ``labeled_span_to_id``) to its embedding as a plain list. Otherwise a Gradio warning is
    shown and no embeddings are computed.

    Parameters:
        document: The document to annotate (modified in place).
        pipeline: The prediction pipeline.
        embedding_model: Optional model used to embed the "labeled_spans" annotations.
        embedding_tokenizer: Tokenizer belonging to ``embedding_model``.
    """
    # execute prediction pipeline
    pipeline(document)

    if embedding_model is not None and embedding_tokenizer is not None:
        adu_embeddings = _embed_text_annotations(
            document=document,
            model=embedding_model,
            tokenizer=embedding_tokenizer,
            text_layer_name="labeled_spans",
        )
        # convert keys to str because JSON keys must be strings
        adu_embeddings_dict = {
            labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
        }
        document.metadata["embeddings"] = adu_embeddings_dict
    else:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
            "model in the 'Model Configuration' section."
        )


def create_and_annotate_document(
    text: str,
    doc_id: str,
    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the
    provided text and annotate it.

    A single partition covering the whole text is added, because the model only considers
    text inside partitions. Then the prediction pipeline is run, and span embeddings are
    computed if an embedding model is available.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        models: A tuple containing the prediction pipeline and the embedding model and tokenizer.

    Returns:
        The processed document.

    Raises:
        gr.Error: If creating or annotating the document fails (the original exception is
            chained as the cause).
    """
    try:
        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
            id=doc_id, text=text, metadata={}
        )
        # add single partition from the whole text (the model only considers text in partitions)
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
        # annotate the document
        _annotate(
            document=document,
            pipeline=models[0],
            embedding_model=models[1],
            embedding_tokenizer=models[2],
        )
        return document
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}") from e


def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
    """Load the argumentation prediction pipeline with ``AutoPipeline``.

    The pipeline is created with ``device=-1`` (CPU by the pytorch-ie pipeline convention)
    and ``num_workers=0``. ``revision`` is forwarded to both the taskmodule and the model.
    A Gradio info message is shown on success.

    Raises:
        gr.Error: If loading fails (the original exception is chained as the cause).
    """
    try:
        model = AutoPipeline.from_pretrained(
            model_name,
            device=-1,
            num_workers=0,
            taskmodule_kwargs=dict(revision=revision),
            model_kwargs=dict(revision=revision),
        )
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}") from e
    gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision}")
    return model


def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load a transformers model and its tokenizer for computing span embeddings.

    A Gradio info message is shown on success.

    Raises:
        gr.Error: If loading fails (the original exception is chained as the cause).
    """
    try:
        embedding_model = AutoModel.from_pretrained(model_name)
        embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        raise gr.Error(f"Failed to load embedding model: {e}") from e
    gr.Info(f"Loaded embedding model: model_name={model_name}")
    return embedding_model, embedding_tokenizer


def load_models(
    model_name: str, revision: Optional[str] = None, embedding_model_name: Optional[str] = None
) -> Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]]:
    """Load the argumentation pipeline and, optionally, an embedding model and tokenizer.

    The embedding model is only loaded when ``embedding_model_name`` is neither None nor
    blank. Surrounding whitespace is stripped from the name before loading. Otherwise the
    last two tuple entries are None.
    """
    argumentation_model = load_argumentation_model(model_name, revision)
    embedding_model = None
    embedding_tokenizer = None
    if embedding_model_name is not None and embedding_model_name.strip():
        embedding_model, embedding_tokenizer = load_embedding_model(embedding_model_name.strip())

    return argumentation_model, embedding_model, embedding_tokenizer