|
import json |
|
import logging |
|
from collections import defaultdict |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
from annotation_utils import labeled_span_to_id |
|
from pytorch_ie import Annotation |
|
from pytorch_ie.documents import ( |
|
TextBasedDocument, |
|
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, |
|
) |
|
from vector_store import SimpleVectorStore, VectorStore |
|
|
|
# Module-level logger named after this module (not referenced by the code visible in this chunk).
logger = logging.getLogger(__name__)
|
|
|
|
|
def get_annotation_from_document(
    document: TextBasedDocument,
    annotation_id: str,
    annotation_layer: str,
    use_predictions: bool,
) -> Annotation:
    """Get an annotation from a document by its id. Note that the annotation id is constructed from
    the annotation itself, so it is unique within the document.

    Args:
        document: The document to get the annotation from.
        annotation_id: The id of the annotation.
        annotation_layer: The name of the annotation layer. Currently only "labeled_spans" is
            supported (ids are computed with `labeled_span_to_id`).
        use_predictions: Whether to use the predictions of the annotation layer.

    Returns:
        The annotation with the given id.

    Raises:
        gr.Error: If the annotation layer is not supported, or if no annotation with the given id
            exists in the (gold or predicted) layer.
    """
    # Validate the layer name before touching the document. This way an unsupported layer always
    # surfaces as a user-facing gr.Error, instead of whatever the document raises when the layer
    # does not exist.
    if annotation_layer == "labeled_spans":
        annotation_to_id_func = labeled_span_to_id
    else:
        raise gr.Error(f"Unknown annotation layer '{annotation_layer}'.")

    annotations = document[annotation_layer]
    if use_predictions:
        annotations = annotations.predictions

    # Index all annotations of the layer by their id. The full mapping is also reported in the
    # error message when the lookup fails.
    id2annotation = {annotation_to_id_func(annotation): annotation for annotation in annotations}
    annotation = id2annotation.get(annotation_id)
    if annotation is None:
        raise gr.Error(
            f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
            f"annotations: {id2annotation}"
        )
    return annotation
|
|
|
|
|
def get_related_annotation_records_from_document(
    document: TextBasedDocument,
    reference_annotation: Annotation,
    relation_layer_name: str,
    use_predictions: bool,
    annotation_caption: str,
    relation_types: Optional[List[str]] = None,
    additional_static_columns: Optional[Dict[str, str]] = None,
) -> List[Dict[str, str]]:
    """Get related annotations from a document for a given reference annotation. The related
    annotations are all annotations that are targets (tails) of relations with the reference
    annotation as source (head).

    Args:
        document: The document to get the related annotations from.
        reference_annotation: The reference annotation. Should be an annotation from the document.
        relation_layer_name: The name of the relation layer.
        use_predictions: Whether to use the predictions of the relation layer.
        annotation_caption: The caption for the related annotations in the result.
        relation_types: The types of relations to consider. If None, all relation types are
            considered.
        additional_static_columns: Additional static columns to add to each record. They are
            applied last, so they override identically named columns.

    Returns:
        One record per matching relation, in relation-layer order. Each record has the keys
        "doc_id", f"reference_{annotation_caption}", "rel_score", "relation",
        `annotation_caption` (the stringified tail), plus the additional static columns.
    """
    relation_layer = document[relation_layer_name]
    if use_predictions:
        relation_layer = relation_layer.predictions

    # Set for O(1) label membership tests. None means "accept every relation type".
    allowed_types = set(relation_types) if relation_types is not None else None
    static_columns = additional_static_columns or {}

    # A single filtering pass over the relations is enough: only relations whose head is the
    # reference annotation contribute to the result.
    result = []
    for rel in relation_layer:
        if allowed_types is not None and rel.label not in allowed_types:
            continue
        if rel.head != reference_annotation:
            continue
        result.append(
            {
                "doc_id": document.id,
                f"reference_{annotation_caption}": str(reference_annotation),
                "rel_score": rel.score,
                "relation": rel.label,
                annotation_caption: str(rel.tail),
                **static_columns,
            }
        )
    return result
|
|
|
|
|
class DocumentStore:
    """A document store that allows to add, retrieve, and search for documents and annotations.

    The store keeps the documents in memory and stores the embeddings of the labeled spans in a
    vector store to efficiently retrieve similar or related spans.

    Args:
        vector_store: The vector store to use. If None, a new SimpleVectorStore is created.
        document_type: The type of the documents to store. Should be a subclass of
            TextBasedDocument with a span and a relation layer (see below).
        span_layer_name: The name of the span annotation layer. This should be a valid annotation
            layer of type LabeledSpan in the document type. NOTE: annotation lookup by id
            (`get_annotation_from_document`) currently only supports the layer "labeled_spans".
        relation_layer_name: The name of the argumentative relation annotation layer. This should
            be a valid annotation layer of type BinaryRelation in the document type.
        span_annotation_caption: The caption for the span annotations (e.g. in the statistical
            overview).
        relation_annotation_caption: The caption for the relation annotations (e.g. in the
            statistical overview).
        use_predictions: Whether to use the predictions of the annotation layers. If True, the
            predictions are used, otherwise the gold annotations are used.
    """

    def __init__(
        self,
        vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None,
        document_type: type[
            TextBasedDocument
        ] = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        span_layer_name: str = "labeled_spans",
        relation_layer_name: str = "binary_relations",
        span_annotation_caption: str = "span",
        relation_annotation_caption: str = "relation",
        use_predictions: bool = True,
    ):
        # Documents are kept in memory, keyed by document id.
        self.documents: Dict[str, TextBasedDocument] = {}

        # Embeddings are stored under the key (doc_id, annotation_id).
        # Compare against None explicitly. A caller-supplied store that is empty (and therefore
        # falsy, if it defines __len__/__bool__) must not be silently replaced.
        self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
            vector_store if vector_store is not None else SimpleVectorStore()
        )

        self.document_type = document_type
        self.span_layer_name = span_layer_name
        self.relation_layer_name = relation_layer_name
        self.use_predictions = use_predictions
        # Maps annotation layer name -> human-readable caption (used e.g. in overview()).
        self.layer_captions = {
            self.span_layer_name: span_annotation_caption,
            self.relation_layer_name: relation_annotation_caption,
        }

    def get_annotation(
        self,
        doc_id: str,
        annotation_id: str,
        annotation_layer: str,
        use_predictions: bool,
    ) -> Annotation:
        """Get an annotation from a stored document by its id.

        Args:
            doc_id: The id of the document.
            annotation_id: The id of the annotation within the document.
            annotation_layer: The name of the annotation layer.
            use_predictions: Whether to use the predictions of the annotation layer.

        Returns:
            The annotation with the given id.

        Raises:
            gr.Error: If the document is not in the store, or the annotation cannot be found.
        """
        document = self.documents.get(doc_id)
        if document is None:
            raise gr.Error(
                f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
            )
        return get_annotation_from_document(
            document, annotation_id, annotation_layer, use_predictions=use_predictions
        )

    def get_similar_annotations_df(
        self,
        ref_annotation_id: str,
        ref_document: TextBasedDocument,
        annotation_layer: str,
        **similarity_kwargs,
    ) -> pd.DataFrame:
        """Get similar annotations from documents in the store sorted by similarity. Usually, the
        reference annotation is returned as the most similar annotation.

        Args:
            ref_annotation_id: The id of the reference annotation.
            ref_document: The document of the reference annotation.
            annotation_layer: The name of the annotation layer to consider.
            **similarity_kwargs: Additional keyword arguments that will be passed to the vector
                store to retrieve similar entries (see VectorStore.retrieve_similar()).

        Returns:
            A DataFrame with the similar annotations with columns: doc_id, annotation_id,
            sim_score, and text.
        """
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            **similarity_kwargs,
        )

        # Resolve every retrieved (doc_id, annotation_id) key back to its annotation object.
        similar_annotations = [
            self.get_annotation(
                doc_id=doc_id,
                annotation_id=annotation_id,
                annotation_layer=annotation_layer,
                use_predictions=self.use_predictions,
            )
            for (doc_id, annotation_id), _ in similar_entries
        ]
        df = pd.DataFrame(
            [
                (doc_id, annotation_id, score, str(annotation))
                for ((doc_id, annotation_id), score), annotation in zip(
                    similar_entries, similar_annotations
                )
            ],
            columns=["doc_id", "annotation_id", "sim_score", "text"],
        )

        return df

    def get_related_annotations_from_other_documents_df(
        self,
        ref_annotation_id: str,
        ref_document: TextBasedDocument,
        min_similarity: float,
        top_k: int,
        relation_types: List[str],
        columns: List[str],
    ) -> pd.DataFrame:
        """Get related annotations from documents in the store for a given reference annotation.
        First, similar annotations are retrieved from the vector store. Then, annotations that are
        linked to them via relations (similar annotation as head, related annotation as tail) are
        returned. Only annotations from other documents are considered.

        Args:
            ref_annotation_id: The id of the reference annotation.
            ref_document: The document of the reference annotation.
            min_similarity: The minimum similarity score to consider.
            top_k: The number of similar annotations to retrieve from the vector store. Entries
                from the reference document are dropped afterwards, and each remaining similar
                annotation can contribute zero or more related annotations. So this is not the
                number of rows in the result.
            relation_types: The types of relations to consider.
            columns: The columns to include in the result DataFrame.

        Returns:
            A DataFrame with the columns that contain: the related annotation, the relation type,
            the similar annotation, the similarity score, and the relation score.
        """
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        result = []
        for (doc_id, annotation_id), score in similar_entries:
            # Skip entries from the reference document itself.
            if doc_id == ref_document.id:
                continue
            document = self.documents[doc_id]
            reference_annotation = get_annotation_from_document(
                document=document,
                annotation_id=annotation_id,
                annotation_layer=self.span_layer_name,
                use_predictions=self.use_predictions,
            )

            new_entries = get_related_annotation_records_from_document(
                document=document,
                reference_annotation=reference_annotation,
                relation_types=relation_types,
                relation_layer_name=self.relation_layer_name,
                use_predictions=self.use_predictions,
                annotation_caption=self.layer_captions[self.span_layer_name],
                additional_static_columns={"sim_score": str(score)},
            )
            result.extend(new_entries)

        df = pd.DataFrame(result, columns=columns)
        return df

    def add_document(self, document: TextBasedDocument) -> None:
        """Add a document to the store and index its span embeddings in the vector store.

        The embeddings are read from ``document.metadata["embeddings"]``, a mapping from
        annotation id to embedding. Each embedding is saved under the key
        (document.id, annotation_id). A document with an id that is already stored is overwritten
        (with a warning).

        Args:
            document: The document to add.

        Raises:
            gr.Error: If the document cannot be added, e.g. because it has no embeddings.
        """
        try:
            # Look up the embeddings before storing the document. A document without embeddings
            # is then rejected up front, instead of being left stored but not indexed.
            embeddings = document.metadata["embeddings"]
            if document.id in self.documents:
                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
            self.documents[document.id] = document
            # NOTE(review): when overwriting, this method does not remove vector store entries
            # of the previous version that are absent from the new embeddings.
            for annotation_id, embedding in embeddings.items():
                self.vector_store.save((document.id, annotation_id), embedding)
        except Exception as e:
            raise gr.Error(f"Failed to add document {document.id} to index: {e}") from e

    def add_document_from_dict(self, document_dict: dict) -> None:
        """Deserialize a document of `self.document_type` from a dict and add it to the store."""
        document = self.document_type.fromdict(document_dict)
        self.add_document(document)

    def add_documents(self, documents: List[TextBasedDocument]) -> None:
        """Add multiple documents to the store and show a summary notification."""
        for document in documents:
            self.add_document(document)
        gr.Info(
            f"Added {len(documents)} documents to the index ({len(self.documents)} documents in total)."
        )

    def add_documents_from_json(self, file_path: str) -> None:
        """Load documents from a JSON file and add them to the store.

        The file must contain a JSON object whose values are serialized documents (as produced by
        `as_dict()`). The keys are ignored; each document's own id is used.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            documents_json = json.load(f)
        for document_json in documents_json.values():
            self.add_document_from_dict(document_dict=document_json)
        gr.Info(
            f"Added {len(documents_json)} documents to the index ({len(self.documents)} documents in total)."
        )

    def save_to_json(self, file_path: str, **kwargs) -> None:
        """Write all stored documents to a JSON file (see `as_dict()`).

        Args:
            file_path: The path of the output file.
            **kwargs: Passed through to `json.dump` (e.g. `indent`).
        """
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self.as_dict(), f, **kwargs)

    def get_document(self, doc_id: str) -> TextBasedDocument:
        """Return the stored document with the given id (raises KeyError if it is not stored)."""
        return self.documents[doc_id]

    def overview(self) -> pd.DataFrame:
        """Return one row per stored document, with the number of annotations per layer.

        The columns are "doc_id" plus one "num_<caption>s" column per configured layer caption.
        They are gold or predicted counts, depending on `use_predictions`.
        """
        # Explicit columns so that an empty store still yields the expected header.
        # dict.fromkeys keeps order and avoids duplicate columns if two captions coincide.
        columns = [
            "doc_id",
            *dict.fromkeys(f"num_{caption}s" for caption in self.layer_captions.values()),
        ]
        rows = []
        for doc_id, document in self.documents.items():
            layers = {
                caption: document[layer_name]
                for layer_name, caption in self.layer_captions.items()
            }
            if self.use_predictions:
                layers = {caption: layer.predictions for caption, layer in layers.items()}
            layer_sizes = {f"num_{caption}s": len(layer) for caption, layer in layers.items()}
            rows.append({"doc_id": doc_id, **layer_sizes})
        df = pd.DataFrame(rows, columns=columns)
        return df

    def as_dict(self) -> dict:
        """Serialize all stored documents into a dict mapping document id to `document.asdict()`."""
        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
|
|