import logging
from typing import Optional, Tuple

import gradio as gr
import torch
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

from annotation_utils import labeled_span_to_id
from embedding import EmbeddingModel, HuggingfaceEmbeddingModel

logger = logging.getLogger(__name__)


def annotate_document(
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
annotation_pipeline: Pipeline,
embedding_model: Optional[EmbeddingModel] = None,
) -> None:
"""Annotate a document with the provided pipeline. If an embedding model is provided, also
    extract embeddings for the labeled spans.

    Args:
        document: The document to annotate.
        annotation_pipeline: The pipeline to use for annotation.
        embedding_model: The embedding model to use for extracting text span embeddings.
    """
# execute prediction pipeline
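    # (predictions are added in-place to the document's annotation layers)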
annotation_pipeline(document)
if embedding_model is not None:
text_span_embeddings = embedding_model(
document=document,
span_layer_name="labeled_spans",
)
# convert keys to str because JSON keys must be strings
text_span_embeddings_dict = {
labeled_span_to_id(k): v.tolist() for k, v in text_span_embeddings.items()
}
document.metadata["embeddings"] = text_span_embeddings_dict
else:
gr.Warning(
"No embedding model provided. Skipping embedding extraction. You can load an embedding "
"model in the 'Model Configuration' section."
        )


def create_document(
text: str, doc_id: str
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
"""Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
    text.

    Args:
        text: The text to process.
        doc_id: The ID of the document.

    Returns:
        The processed document.
    """
document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
id=doc_id, text=text, metadata={}
)
# add single partition from the whole text (the model only considers text in partitions)
document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
    return document


def load_argumentation_model(
model_name: str,
revision: Optional[str] = None,
device: str = "cpu",
) -> Pipeline:
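    """Load a pytorch-ie pipeline for argumentation structure annotation.

    Args:
        model_name: The name of the model on the Hugging Face Hub.
        revision: The model revision (branch, tag, or commit hash) to load.
        device: The device to load the model onto ("cpu", "cuda", or "cuda:N").

    Returns:
        The loaded pipeline.
    """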
try:
# the Pipeline class expects an integer for the device
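        # e.g. "cuda" -> 0, "cuda:1" -> 1, "cpu" -> -1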
if device == "cuda":
pipeline_device = 0
elif device.startswith("cuda:"):
pipeline_device = int(device.split(":")[1])
elif device == "cpu":
pipeline_device = -1
else:
raise gr.Error(f"Invalid device: {device}")
model = AutoPipeline.from_pretrained(
model_name,
device=pipeline_device,
num_workers=0,
taskmodule_kwargs=dict(revision=revision),
model_kwargs=dict(revision=revision),
)
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}") from e

    gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision}")
    return model


def load_models(
model_name: str,
revision: Optional[str] = None,
embedding_model_name: Optional[str] = None,
# embedding_model_revision: Optional[str] = None,
embedding_max_length: int = 512,
embedding_batch_size: int = 16,
device: str = "cpu",
) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
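    """Load the argumentation pipeline and, optionally, an embedding model.

    Args:
        model_name: The name of the argumentation model on the Hugging Face Hub.
        revision: The revision of the argumentation model to load.
        embedding_model_name: The name of the embedding model. If None or empty,
            no embedding model is loaded.
        embedding_max_length: The maximum input length for the embedding model.
        embedding_batch_size: The batch size for embedding computation.
        device: The device to load the models onto.

    Returns:
        A tuple of the argumentation pipeline and the optional embedding model.
    """
    # release cached GPU memory from any previously loaded models (no-op without CUDA)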
torch.cuda.empty_cache()
argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
embedding_model = None
if embedding_model_name is not None and embedding_model_name.strip():
try:
embedding_model = HuggingfaceEmbeddingModel(
embedding_model_name.strip(),
# revision=embedding_model_revision,
device=device,
max_length=embedding_max_length,
batch_size=embedding_batch_size,
)
        except Exception as e:
            raise gr.Error(f"Failed to load embedding model: {e}") from e

return argumentation_model, embedding_model
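

if __name__ == "__main__":
    # Minimal usage sketch, kept behind the main guard. Assumptions: the
    # checkpoint name below is a placeholder (any pytorch-ie pointer-network
    # argumentation model should work), and span embeddings are skipped here.
    argumentation_model, embedding_model = load_models(
        model_name="ArneBinder/sam-pointer-bart-base-v0.3",  # assumed checkpoint name
        embedding_model_name=None,  # e.g. a Hugging Face encoder name to also get span embeddings
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    doc = create_document(
        text="We should ban plastic bags. They pollute the oceans.",
        doc_id="demo-doc",
    )
    annotate_document(doc, argumentation_model, embedding_model)
    # predicted argumentative spans and relations land in the predictions layers
    print(list(doc.labeled_spans.predictions))
    print(list(doc.binary_relations.predictions))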