import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import gradio as gr
import pandas as pd
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils import labeled_span_to_id
from transformers import PreTrainedModel, PreTrainedTokenizer
from vector_store import SimpleVectorStore, VectorStore
logger = logging.getLogger(__name__)
def _embed_text_annotations(
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
text_layer_name: str,
) -> Dict[Span, List[float]]:
    # work on a copy so that the original document is not modified
    document = document.copy()
    # tokenize_document does not yet consider predictions, so we move the predicted annotations
    # into the main layer (predictions.clear() returns the removed annotations)
    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
added_annotations = []
tokenizer_kwargs = {
"max_length": 512,
"stride": 64,
"truncation": True,
"return_overflowing_tokens": True,
}
tokenized_documents = tokenize_document(
document,
tokenizer=tokenizer,
result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
partition_layer="labeled_partitions",
added_annotations=added_annotations,
strict_span_conversion=False,
**tokenizer_kwargs,
)
# just tokenize again to get tensors in the correct format for the model
model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
# this is added when using return_overflowing_tokens=True, but the model does not accept it
model_inputs.pop("overflow_to_sample_mapping", None)
assert len(model_inputs.encodings) == len(tokenized_documents)
model_output = model(**model_inputs)
# get embeddings for all text annotations
embeddings = {}
for batch_idx in range(len(model_output.last_hidden_state)):
text2tok_ann = added_annotations[batch_idx][text_layer_name]
tok2text_ann = {v: k for k, v in text2tok_ann.items()}
for tok_ann in tokenized_documents[batch_idx].labeled_spans:
# skip "empty" annotations
if tok_ann.start == tok_ann.end:
continue
# use the max pooling strategy to get a single embedding for the annotation text
embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
dim=0
)[0]
text_ann = tok2text_ann[tok_ann]
if text_ann in embeddings:
logger.warning(
f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
)
embeddings[text_ann] = embedding
return embeddings
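
# Usage sketch for _embed_text_annotations. Note: "bert-base-uncased" below is only a
# placeholder encoder for illustration, not necessarily the embedding model this app loads;
# any transformers encoder that returns last_hidden_state should work the same way.
#
#     from transformers import AutoModel, AutoTokenizer
#
#     emb_model = AutoModel.from_pretrained("bert-base-uncased")
#     emb_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     span_embeddings = _embed_text_annotations(
#         document=annotated_document,
#         model=emb_model,
#         tokenizer=emb_tokenizer,
#         text_layer_name="labeled_spans",
#     )
#     # keys are LabeledSpan annotations, values are max-pooled token embeddings (tensors)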
def _annotate(
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
pipeline: Pipeline,
embedding_model: Optional[PreTrainedModel] = None,
embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
) -> None:
# execute prediction pipeline
pipeline(document)
if embedding_model is not None and embedding_tokenizer is not None:
adu_embeddings = _embed_text_annotations(
document=document,
model=embedding_model,
tokenizer=embedding_tokenizer,
text_layer_name="labeled_spans",
)
# convert keys to str because JSON keys must be strings
adu_embeddings_dict = {
labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
}
document.metadata["embeddings"] = adu_embeddings_dict
else:
gr.Warning(
"No embedding model provided. Skipping embedding extraction. You can load an embedding "
"model in the 'Model Configuration' section."
)
def create_and_annotate_document(
text: str,
doc_id: str,
models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
"""Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
text, annotate it, and add it to the index.
Parameters:
text: The text to process.
doc_id: The ID of the document.
models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
Returns:
The processed document.
"""
try:
document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
id=doc_id, text=text, metadata={}
)
        # add a single partition over the whole text (the model only considers text in partitions)
document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
# annotate the document
_annotate(
document=document,
pipeline=models[0],
embedding_model=models[1],
embedding_tokenizer=models[2],
)
return document
except Exception as e:
raise gr.Error(f"Failed to process text: {e}")
def get_annotation_from_document(
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
annotation_id: str,
annotation_layer: str,
) -> LabeledSpan:
# use predictions
annotations = document[annotation_layer].predictions
id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
annotation = id2annotation.get(annotation_id)
if annotation is None:
raise gr.Error(
f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
f"annotations: {id2annotation}"
)
return annotation
class DocumentStore:
DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str]]] = None):
        self.documents: Dict[
            str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        ] = {}
self.vector_store = vector_store or SimpleVectorStore()
def get_annotation(
self,
doc_id: str,
annotation_id: str,
annotation_layer: str,
) -> LabeledSpan:
document = self.documents.get(doc_id)
if document is None:
raise gr.Error(
f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
)
return get_annotation_from_document(document, annotation_id, annotation_layer)
def get_similar_adus_df(
self,
ref_annotation_id: str,
ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
min_similarity: float,
top_k: int,
) -> pd.DataFrame:
similar_entries = self.vector_store.retrieve_similar(
ref_id=(ref_document.id, ref_annotation_id),
min_similarity=min_similarity,
top_k=top_k,
)
similar_annotations = [
self.get_annotation(
doc_id=doc_id,
annotation_id=annotation_id,
annotation_layer="labeled_spans",
)
for (doc_id, annotation_id), _ in similar_entries
]
df = pd.DataFrame(
[
# unpack the tuple (doc_id, annotation_id) to separate columns
# and add the similarity score and the text of the annotation
(doc_id, annotation_id, score, str(annotation))
for ((doc_id, annotation_id), score), annotation in zip(
similar_entries, similar_annotations
)
],
columns=["doc_id", "adu_id", "sim_score", "text"],
)
return df
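
    # Usage sketch (assumes a populated store; `ref_adu_id` would be obtained via
    # labeled_span_to_id() for a span of the reference document):
    #
    #     df = store.get_similar_adus_df(
    #         ref_annotation_id=ref_adu_id,
    #         ref_document=store.get_document("doc-1"),
    #         min_similarity=0.8,
    #         top_k=10,
    #     )
    #     # columns: doc_id, adu_id, sim_score, text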
def get_relevant_adus_df(
self,
ref_annotation_id: str,
ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
min_similarity: float,
top_k: int,
relation_types: List[str],
columns: List[str],
) -> pd.DataFrame:
similar_entries = self.vector_store.retrieve_similar(
ref_id=(ref_document.id, ref_annotation_id),
min_similarity=min_similarity,
top_k=top_k,
)
result = []
for (doc_id, annotation_id), score in similar_entries:
# skip entries from the same document
if doc_id == ref_document.id:
continue
document = self.documents[doc_id]
tail2rels = defaultdict(list)
head2rels = defaultdict(list)
for rel in document.binary_relations.predictions:
# skip non-argumentative relations
if rel.label not in relation_types:
continue
head2rels[rel.head].append(rel)
tail2rels[rel.tail].append(rel)
id2annotation = {
labeled_span_to_id(annotation): annotation
for annotation in document.labeled_spans.predictions
}
annotation = id2annotation.get(annotation_id)
# note: we do not need to check if the annotation is different from the reference annotation,
# because they come from different documents and we already skip entries from the same document
for rel in head2rels.get(annotation, []):
result.append(
{
"doc_id": doc_id,
"reference_adu": str(annotation),
"sim_score": score,
"rel_score": rel.score,
"relation": rel.label,
"adu": str(rel.tail),
}
)
# define column order
df = pd.DataFrame(result, columns=columns)
return df
def add_document(
self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
) -> None:
try:
if document.id in self.documents:
gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
# save the processed document to the index
self.documents[document.id] = document
            # save the embeddings to the vector store (embeddings may be missing if no
            # embedding model was loaded when the document was annotated)
            for adu_id, embedding in document.metadata.get("embeddings", {}).items():
self.vector_store.save((document.id, adu_id), embedding)
gr.Info(
f"Added document {document.id} to index (index contains {len(self.documents)} "
f"documents and {len(self.vector_store)} embeddings)."
)
except Exception as e:
raise gr.Error(f"Failed to add document {document.id} to index: {e}")
def add_document_from_dict(self, document_dict: dict) -> None:
document = self.DOCUMENT_TYPE.fromdict(document_dict)
# metadata is not automatically deserialized, so we need to set it manually
document.metadata = document_dict["metadata"]
self.add_document(document)
def add_documents(
self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
) -> None:
for document in documents:
self.add_document(document)
def get_document(
self, doc_id: str
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
return self.documents[doc_id]
def overview(self) -> pd.DataFrame:
df = pd.DataFrame(
[
(
doc_id,
len(document.labeled_spans.predictions),
len(document.binary_relations.predictions),
)
for doc_id, document in self.documents.items()
],
columns=["doc_id", "num_adus", "num_relations"],
)
return df
def as_dict(self) -> dict:
return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
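
# End-to-end usage sketch. Note: loading the prediction pipeline and the embedding model is
# only hinted at here; the actual model names and loading code live in the app, not in this
# module.
#
#     store = DocumentStore()
#     doc = create_and_annotate_document(
#         text="Automated fact checking is useful because it scales to large corpora.",
#         doc_id="doc-1",
#         models=(pipeline, embedding_model, embedding_tokenizer),
#     )
#     store.add_document(doc)
#     print(store.overview())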