import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils import labeled_span_to_id
from transformers import PreTrainedModel, PreTrainedTokenizer
from vector_store import VectorStore

logger = logging.getLogger(__name__)


def embed_text_annotations(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    text_layer_name: str,
) -> Dict[Span, List[float]]:
    """Compute one embedding per annotation of the layer `text_layer_name`.

    The document is copied, tokenized into (possibly overlapping) windows, passed through
    `model`, and the hidden states covered by each annotation are max-pooled over the token
    dimension.

    Parameters:
        document: The document whose annotations should be embedded. It is not modified.
        model: The model used to produce `last_hidden_state`.
        tokenizer: The tokenizer matching `model`.
        text_layer_name: Name of the annotation layer to embed.

    Returns:
        A mapping from the text-based annotation to its pooled embedding.
        NOTE(review): the values are whatever slicing/pooling `last_hidden_state` yields
        (the caller in this file converts them with `.detach().tolist()`), not plain lists
        as the return annotation suggests.

    Raises:
        ValueError: If the tokenizer and `tokenize_document` produce a different number of
            windows, in which case hidden states could not be aligned with the annotations.
    """
    # to not modify the original document
    document = document.copy()
    # tokenize_document does not yet consider predictions, so we need to add them manually
    # (relies on `predictions.clear()` handing back the annotations it removed)
    document[text_layer_name].extend(document[text_layer_name].predictions.clear())

    # filled by tokenize_document: per window, per layer name, a mapping from the text-based
    # annotation to its token-based counterpart
    added_annotations: List[Dict[str, Dict]] = []
    tokenizer_kwargs = {
        "max_length": 512,
        "stride": 64,
        "truncation": True,
        "return_overflowing_tokens": True,
    }
    tokenized_documents = tokenize_document(
        document,
        tokenizer=tokenizer,
        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        partition_layer="labeled_partitions",
        added_annotations=added_annotations,
        strict_span_conversion=False,
        **tokenizer_kwargs,
    )
    # just tokenize again to get tensors in the correct format for the model
    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
    # this is added when using return_overflowing_tokens=True, but the model does not accept it
    model_inputs.pop("overflow_to_sample_mapping", None)
    # explicit check instead of `assert`, which would be stripped when running with -O
    if len(model_inputs.encodings) != len(tokenized_documents):
        raise ValueError(
            f"Tokenization mismatch: the tokenizer produced {len(model_inputs.encodings)} "
            f"window(s), but tokenize_document produced {len(tokenized_documents)}."
        )
    model_output = model(**model_inputs)

    # get embeddings for all text annotations
    embeddings = {}
    for batch_idx, hidden_states in enumerate(model_output.last_hidden_state):
        text2tok_ann = added_annotations[batch_idx][text_layer_name]
        tok2text_ann = {tok: text for text, tok in text2tok_ann.items()}
        # use the requested layer (not a hard-coded one) so it matches the mapping above
        for tok_ann in tokenized_documents[batch_idx][text_layer_name]:
            # skip "empty" annotations
            if tok_ann.start == tok_ann.end:
                continue
            # use the max pooling strategy to get a single embedding for the annotation text
            embedding = hidden_states[tok_ann.start : tok_ann.end].max(dim=0)[0]
            text_ann = tok2text_ann[tok_ann]
            if text_ann in embeddings:
                # the same annotation can show up in several overlapping windows
                logger.warning(
                    "Overwriting embedding for annotation '%s' (do you use striding?)",
                    text_ann,
                )
            embeddings[text_ann] = embedding
    return embeddings


def annotate(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    pipeline: Pipeline,
    embedding_model: Optional[PreTrainedModel] = None,
    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
) -> None:
    """Run the prediction pipeline on `document` and, if possible, attach ADU embeddings.

    Parameters:
        document: The document to annotate; it is modified in place.
        pipeline: The prediction pipeline that is called with the document.
        embedding_model: Optional model used to embed the "labeled_spans" annotations.
        embedding_tokenizer: Optional tokenizer belonging to `embedding_model`.

    When both the embedding model and tokenizer are given, the embeddings are stored under
    `document.metadata["embeddings"]`, keyed by `labeled_span_to_id(annotation)`. Otherwise a
    gradio warning is shown and no "embeddings" entry is written.
    """
    # execute prediction pipeline
    pipeline(document)

    if embedding_model is not None and embedding_tokenizer is not None:
        adu_embeddings = embed_text_annotations(
            document=document,
            model=embedding_model,
            tokenizer=embedding_tokenizer,
            text_layer_name="labeled_spans",
        )
        # convert keys to str because JSON keys must be strings
        adu_embeddings_dict = {
            labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
        }
        document.metadata["embeddings"] = adu_embeddings_dict
    else:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
            "model in the 'Model Configuration' section."
        )


def add_to_index(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    processed_documents: dict,
    vector_store: VectorStore[Tuple[str, str]],
) -> None:
    """Store `document` in `processed_documents` and its embeddings in `vector_store`.

    Parameters:
        document: The annotated document. Embeddings are read from
            `document.metadata["embeddings"]` if present.
        processed_documents: Mapping from document id to document; updated in place. An
            existing entry with the same id is overwritten (with a gradio warning).
        vector_store: Store that receives one entry per embedding, keyed by
            `(document.id, adu_id)`.

    Raises:
        gr.Error: If anything goes wrong while indexing; the original exception is chained.
    """
    try:
        if document.id in processed_documents:
            gr.Warning(f"Document '{document.id}' already in index. Overwriting.")

        # save the processed document to the index
        processed_documents[document.id] = document

        # save the embeddings to the vector store; `annotate` writes no "embeddings" entry
        # when no embedding model is loaded, so a missing key means "nothing to store"
        embeddings = document.metadata.get("embeddings", {})
        for adu_id, embedding in embeddings.items():
            vector_store.save((document.id, adu_id), embedding)

        gr.Info(
            f"Added document {document.id} to index (index contains {len(processed_documents)} "
            f"documents and {len(vector_store)} embeddings)."
        )
    except Exception as e:
        raise gr.Error(f"Failed to add document {document.id} to index: {e}") from e


def process_text(
    text: str,
    doc_id: str,
    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
    processed_documents: dict[
        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    ],
    vector_store: VectorStore[Tuple[str, str]],
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the
    provided text, annotate it, and add it to the index.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
        processed_documents: The index of processed documents.
        vector_store: The vector store to save the embeddings.

    Returns:
        The processed document.

    Raises:
        gr.Error: If any step fails; the original exception is chained.
    """
    try:
        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
            id=doc_id, text=text, metadata={}
        )
        # add single partition from the whole text (the model only considers text in partitions)
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
        # annotate the document
        annotate(
            document=document,
            pipeline=models[0],
            embedding_model=models[1],
            embedding_tokenizer=models[2],
        )
        # add the document to the index
        add_to_index(document, processed_documents, vector_store)
        return document
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}") from e


def get_annotation_from_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_id: str,
    annotation_layer: str,
) -> LabeledSpan:
    """Look up a predicted annotation of `document` by its id.

    Parameters:
        document: The document to search.
        annotation_id: The id as produced by `labeled_span_to_id`.
        annotation_layer: Name of the layer whose predictions are searched.

    Returns:
        The matching annotation.

    Raises:
        gr.Error: If no prediction in the layer has the requested id.
    """
    # use predictions
    annotations = document[annotation_layer].predictions
    id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
    annotation = id2annotation.get(annotation_id)
    if annotation is None:
        raise gr.Error(
            f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
            f"annotations: {id2annotation}"
        )
    return annotation


def get_annotation_from_processed_documents(
    doc_id: str,
    annotation_id: str,
    annotation_layer: str,
    processed_documents: dict[
        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    ],
) -> LabeledSpan:
    """Look up a predicted annotation by document id and annotation id.

    Parameters:
        doc_id: The id of the document in `processed_documents`.
        annotation_id: The id as produced by `labeled_span_to_id`.
        annotation_layer: Name of the layer whose predictions are searched.
        processed_documents: The index of processed documents.

    Returns:
        The matching annotation.

    Raises:
        gr.Error: If the document is not in the index or the annotation is not found in it.
    """
    document = processed_documents.get(doc_id)
    if document is None:
        raise gr.Error(
            f"Document '{doc_id}' not found in index. "
            f"Available documents: {list(processed_documents)}"
        )
    return get_annotation_from_document(document, annotation_id, annotation_layer)


def get_similar_adus(
    ref_annotation_id: str,
    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    vector_store: VectorStore[Tuple[str, str]],
    processed_documents: dict[
        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    ],
    min_similarity: float,
    top_k: int,
) -> pd.DataFrame:
    """Retrieve the ADUs most similar to a reference ADU.

    Parameters:
        ref_annotation_id: The id of the reference annotation.
        ref_document: The document that contains the reference annotation.
        vector_store: The vector store that is queried for similar entries.
        processed_documents: The index of processed documents, used to resolve the hits.
        min_similarity: Passed through to `vector_store.retrieve_similar`.
        top_k: Passed through to `vector_store.retrieve_similar`.

    Returns:
        A data frame with the columns "doc_id", "adu_id", "sim_score" and "text", one row per
        retrieved entry, in the order returned by the vector store.

    Raises:
        gr.Error: If a retrieved entry can not be resolved to an annotation in the index.
    """
    similar_entries = vector_store.retrieve_similar(
        ref_id=(ref_document.id, ref_annotation_id),
        min_similarity=min_similarity,
        top_k=top_k,
    )

    rows = []
    for (doc_id, annotation_id), score in similar_entries:
        annotation = get_annotation_from_processed_documents(
            doc_id=doc_id,
            annotation_id=annotation_id,
            annotation_layer="labeled_spans",
            processed_documents=processed_documents,
        )
        # unpack the tuple (doc_id, annotation_id) to separate columns
        # and add the similarity score and the string form of the annotation
        rows.append((doc_id, annotation_id, score, str(annotation)))

    return pd.DataFrame(rows, columns=["doc_id", "adu_id", "sim_score", "text"])


def _group_relations_by_head(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    relation_types: List[str],
) -> Dict[LabeledSpan, list]:
    """Group the predicted binary relations with a label in `relation_types` by their head."""
    head2rels = defaultdict(list)
    for rel in document.binary_relations.predictions:
        # skip non-argumentative relations
        if rel.label not in relation_types:
            continue
        head2rels[rel.head].append(rel)
    return head2rels


def get_relevant_adus(
    ref_annotation_id: str,
    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    vector_store: VectorStore[Tuple[str, str]],
    processed_documents: dict[
        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    ],
    min_similarity: float,
    top_k: int,
    relation_types: List[str],
) -> pd.DataFrame:
    """Retrieve ADUs from other documents that are related to ADUs similar to the reference.

    For every entry similar to the reference ADU that comes from a different document, all
    predicted relations of the given types that have this entry as head are collected, and
    the tail of each such relation is reported.

    Parameters:
        ref_annotation_id: The id of the reference annotation.
        ref_document: The document that contains the reference annotation.
        vector_store: The vector store that is queried for similar entries.
        processed_documents: The index of processed documents.
        min_similarity: Passed through to `vector_store.retrieve_similar`.
        top_k: Passed through to `vector_store.retrieve_similar`.
        relation_types: Labels of the relations to follow; other relations are ignored.

    Returns:
        A data frame with the columns "text", "relation", "doc_id", "reference_adu",
        "sim_score" and "rel_score", one row per followed relation.

    Raises:
        KeyError: If a retrieved entry refers to a document missing from `processed_documents`.
    """
    similar_entries = vector_store.retrieve_similar(
        ref_id=(ref_document.id, ref_annotation_id),
        min_similarity=min_similarity,
        top_k=top_k,
    )

    # per-document lookups (id -> annotation, head -> relations), built once per document
    # because several similar entries may come from the same document
    lookups: Dict[str, Tuple[dict, dict]] = {}
    result = []
    for (doc_id, annotation_id), score in similar_entries:
        # skip entries from the same document
        if doc_id == ref_document.id:
            continue
        if doc_id not in lookups:
            document = processed_documents[doc_id]
            id2annotation = {
                labeled_span_to_id(annotation): annotation
                for annotation in document.labeled_spans.predictions
            }
            lookups[doc_id] = (id2annotation, _group_relations_by_head(document, relation_types))
        id2annotation, head2rels = lookups[doc_id]

        # an id that can not be resolved yields None, which has no relations and is skipped
        annotation = id2annotation.get(annotation_id)
        # note: we do not need to check if the annotation is different from the reference
        # annotation, because they come from different documents and we already skip entries
        # from the same document
        for rel in head2rels.get(annotation, []):
            result.append(
                {
                    "doc_id": doc_id,
                    "reference_adu": str(annotation),
                    "sim_score": score,
                    "rel_score": rel.score,
                    "relation": rel.label,
                    "text": str(rel.tail),
                }
            )

    # define column order
    df = pd.DataFrame(
        result, columns=["text", "relation", "doc_id", "reference_adu", "sim_score", "rel_score"]
    )
    return df