import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils import labeled_span_to_id
from transformers import PreTrainedModel, PreTrainedTokenizer
from vector_store import SimpleVectorStore, VectorStore

logger = logging.getLogger(__name__)


def _embed_text_annotations(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    text_layer_name: str,
) -> Dict[Span, List[float]]:
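    """Compute an embedding for each annotation in the given text layer of the document.

    The document is tokenized into (possibly overlapping) windows that respect the labeled
    partitions, the model is run on the corresponding inputs, and each annotation is embedded
    by max-pooling the last hidden states of its tokens. Returns a mapping from the original
    text annotations to their embeddings.
    """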
    # work on a copy so that the original document is not modified
    document = document.copy()
    # move the predicted spans into the main layer (clear() returns the removed annotations)
    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
    added_annotations = []
    tokenizer_kwargs = {
        "max_length": 512,
        "stride": 64,
        "truncation": True,
        "return_overflowing_tokens": True,
    }
    tokenized_documents = tokenize_document(
        document,
        tokenizer=tokenizer,
        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        partition_layer="labeled_partitions",
        added_annotations=added_annotations,
        strict_span_conversion=False,
        **tokenizer_kwargs,
    )

    # tokenize again to get tensor inputs for the model (same windows as above)
    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
    # remove the entry added by return_overflowing_tokens=True; the model does not accept it
    model_inputs.pop("overflow_to_sample_mapping", None)
    assert len(model_inputs.encodings) == len(tokenized_documents)
    model_output = model(**model_inputs)

    embeddings = {}
    for batch_idx in range(len(model_output.last_hidden_state)):
        text2tok_ann = added_annotations[batch_idx][text_layer_name]
        tok2text_ann = {v: k for k, v in text2tok_ann.items()}
        for tok_ann in tokenized_documents[batch_idx].labeled_spans:
            # skip empty annotations
            if tok_ann.start == tok_ann.end:
                continue
            # max-pool the token embeddings of the annotation
            embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
                dim=0
            )[0]
            text_ann = tok2text_ann[tok_ann]

            if text_ann in embeddings:
                logger.warning(
                    f"Overwriting embedding for annotation '{text_ann}' (are you using striding?)"
                )
            embeddings[text_ann] = embedding

    return embeddings


def _annotate(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    pipeline: Pipeline,
    embedding_model: Optional[PreTrainedModel] = None,
    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
) -> None:
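    """Annotate the document in place with the given pipeline.

    If an embedding model and tokenizer are provided, embeddings for the predicted ADUs
    (labeled spans) are computed as well and stored in document.metadata["embeddings"],
    keyed by the span ids produced by labeled_span_to_id.
    """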
    # the pipeline adds its predictions to the document in place
    pipeline(document)

    if embedding_model is not None and embedding_tokenizer is not None:
        adu_embeddings = _embed_text_annotations(
            document=document,
            model=embedding_model,
            tokenizer=embedding_tokenizer,
            text_layer_name="labeled_spans",
        )
        # convert keys to ids and tensors to plain lists so that the metadata is serializable
        adu_embeddings_dict = {
            labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
        }
        document.metadata["embeddings"] = adu_embeddings_dict
    else:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
            "model in the 'Model Configuration' section."
        )


def create_and_annotate_document(
    text: str,
    doc_id: str,
    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
    text and annotate it.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        models: A tuple containing the prediction pipeline and the (optional) embedding model
            and tokenizer.

    Returns:
        The processed document.
    """

    try:
        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
            id=doc_id, text=text, metadata={}
        )
        # add a single partition that spans the whole text
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))

        _annotate(
            document=document,
            pipeline=models[0],
            embedding_model=models[1],
            embedding_tokenizer=models[2],
        )

        return document
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")


def get_annotation_from_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_id: str,
    annotation_layer: str,
) -> LabeledSpan:
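    """Look up a predicted annotation by its id in the given layer of the document.

    Raises a gr.Error if no annotation with that id exists.
    """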
    annotations = document[annotation_layer].predictions
    id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
    annotation = id2annotation.get(annotation_id)
    if annotation is None:
        raise gr.Error(
            f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
            f"annotations: {id2annotation}"
        )
    return annotation


class DocumentStore:
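    """In-memory store for annotated documents.

    Keeps the documents themselves as well as a vector store over their ADU embeddings,
    which is used to retrieve similar and related ADUs across documents.
    """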

    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str]]] = None):
        self.documents: Dict[
            str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        ] = {}
        self.vector_store = vector_store or SimpleVectorStore()

    def get_annotation(
        self,
        doc_id: str,
        annotation_id: str,
        annotation_layer: str,
    ) -> LabeledSpan:
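        """Fetch a predicted annotation by document id, annotation id, and layer name.

        Raises a gr.Error if the document is not in the store or the annotation is not found.
        """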
        document = self.documents.get(doc_id)
        if document is None:
            raise gr.Error(
                f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
            )
        return get_annotation_from_document(document, annotation_id, annotation_layer)

    def get_similar_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
    ) -> pd.DataFrame:
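        """Retrieve the ADUs most similar to the reference annotation from the vector store.

        Returns a DataFrame with the columns doc_id, adu_id, sim_score, and text, containing at
        most top_k entries with a similarity of at least min_similarity.
        """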
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )

        similar_annotations = [
            self.get_annotation(
                doc_id=doc_id,
                annotation_id=annotation_id,
                annotation_layer="labeled_spans",
            )
            for (doc_id, annotation_id), _ in similar_entries
        ]
        df = pd.DataFrame(
            [
                (doc_id, annotation_id, score, str(annotation))
                for ((doc_id, annotation_id), score), annotation in zip(
                    similar_entries, similar_annotations
                )
            ],
            columns=["doc_id", "adu_id", "sim_score", "text"],
        )

        return df

    def get_relevant_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
        relation_types: List[str],
        columns: List[str],
    ) -> pd.DataFrame:
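        """Retrieve ADUs that are related to ADUs similar to the reference annotation.

        For each similar ADU retrieved from another document, all predicted relations of the
        given types that have this ADU as head are collected, and their tails are returned as
        a DataFrame with the given columns.
        """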
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        result = []
        for (doc_id, annotation_id), score in similar_entries:
            # skip entries from the reference document itself
            if doc_id == ref_document.id:
                continue
            document = self.documents[doc_id]
            tail2rels = defaultdict(list)
            head2rels = defaultdict(list)
            for rel in document.binary_relations.predictions:
                # consider only the requested relation types
                if rel.label not in relation_types:
                    continue
                head2rels[rel.head].append(rel)
                tail2rels[rel.tail].append(rel)

            id2annotation = {
                labeled_span_to_id(annotation): annotation
                for annotation in document.labeled_spans.predictions
            }
            annotation = id2annotation.get(annotation_id)

            # collect the tails of all relations that have the retrieved ADU as head
            for rel in head2rels.get(annotation, []):
                result.append(
                    {
                        "doc_id": doc_id,
                        "reference_adu": str(annotation),
                        "sim_score": score,
                        "rel_score": rel.score,
                        "relation": rel.label,
                        "adu": str(rel.tail),
                    }
                )

        # pass the columns explicitly so that an empty result still yields the expected columns
        df = pd.DataFrame(result, columns=columns)
        return df

    def add_document(
        self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    ) -> None:
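        """Add an annotated document to the store and index its ADU embeddings.

        The embeddings are expected in document.metadata["embeddings"], keyed by ADU id, as
        produced by _annotate. A warning is shown if a document with the same id already exists.
        """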
        try:
            if document.id in self.documents:
                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")

            self.documents[document.id] = document

            # index the ADU embeddings, keyed by (document id, ADU id)
            for adu_id, embedding in document.metadata["embeddings"].items():
                self.vector_store.save((document.id, adu_id), embedding)

            gr.Info(
                f"Added document {document.id} to index (index contains {len(self.documents)} "
                f"documents and {len(self.vector_store)} embeddings)."
            )
        except Exception as e:
            raise gr.Error(f"Failed to add document {document.id} to index: {e}")

    def add_document_from_dict(self, document_dict: dict) -> None:
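        """Reconstruct a document from its dict representation and add it to the store."""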
        document = self.DOCUMENT_TYPE.fromdict(document_dict)

        # attach the original metadata (e.g. the ADU embeddings) to the reconstructed document
        document.metadata = document_dict["metadata"]
        self.add_document(document)

    def add_documents(
        self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
    ) -> None:
        for document in documents:
            self.add_document(document)

    def get_document(
        self, doc_id: str
    ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
        return self.documents[doc_id]

    def overview(self) -> pd.DataFrame:
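        """Return one row per stored document with its id and the number of predicted ADUs and
        relations."""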
        df = pd.DataFrame(
            [
                (
                    doc_id,
                    len(document.labeled_spans.predictions),
                    len(document.binary_relations.predictions),
                )
                for doc_id, document in self.documents.items()
            ],
            columns=["doc_id", "num_adus", "num_relations"],
        )
        return df

    def as_dict(self) -> dict:
        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
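

# Usage sketch (illustrative only; how the pipeline and embedding model are loaded depends on
# the surrounding app, and load_models below is a hypothetical placeholder):
#
#   pipeline, embedding_model, embedding_tokenizer = load_models(...)
#   store = DocumentStore()
#   document = create_and_annotate_document(
#       "Some text ...", "doc-1", (pipeline, embedding_model, embedding_tokenizer)
#   )
#   store.add_document(document)
#   print(store.overview())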