import logging
from typing import Optional, Tuple

import gradio as gr
import torch
from annotation_utils import labeled_span_to_id
from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
from pie_modules.document.processing import RegexPartitioner
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

logger = logging.getLogger(__name__)


def annotate_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_pipeline: Pipeline,
    embedding_model: Optional[EmbeddingModel] = None,
) -> None:
    """Annotate a document with the provided pipeline. If an embedding model is provided, also
    extract embeddings for the labeled spans.

    Args:
        document: The document to annotate.
        annotation_pipeline: The pipeline to use for annotation.
        embedding_model: The embedding model to use for extracting text span embeddings.
            If None, embedding extraction is skipped.
    """

    # execute prediction pipeline
    annotation_pipeline(document)

    if embedding_model is not None:
        text_span_embeddings = embedding_model(
            document=document,
            span_layer_name="labeled_spans",
        )
        # convert keys to str because JSON keys must be strings
        text_span_embeddings_dict = {
            labeled_span_to_id(k): v.tolist() for k, v in text_span_embeddings.items()
        }
        document.metadata["embeddings"] = text_span_embeddings_dict
    else:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
            "model in the 'Model Configuration' section."
        )


def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
    text.

    Args:
        text: The text to process.
        doc_id: The ID of the document.
        split_regex: A regular expression pattern used to split the text into partitions.
            If None, the whole text is added as a single partition.

    Returns:
        The processed document.
    """

    document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )
    if split_regex is not None:
        partitioner = RegexPartitioner(
            pattern=split_regex, partition_layer_name="labeled_partitions"
        )
        document = partitioner(document)
    else:
        # add single partition from the whole text (the model only considers text in partitions)
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
    return document


def load_argumentation_model(
    model_name: str,
    revision: Optional[str] = None,
    device: str = "cpu",
) -> Pipeline:
    """Load an argumentation model as a pytorch-ie pipeline.

    Args:
        model_name: The name or path of the model to load with AutoPipeline.from_pretrained.
        revision: The revision of the taskmodule and model to load.
        device: The device to load the model onto ("cpu", "cuda", or "cuda:<index>").

    Returns:
        The loaded pipeline.
    """

    try:
        # the Pipeline class expects an integer for the device
        if device == "cuda":
            pipeline_device = 0
        elif device.startswith("cuda:"):
            pipeline_device = int(device.split(":")[1])
        elif device == "cpu":
            pipeline_device = -1
        else:
            raise gr.Error(f"Invalid device: {device}")

        model = AutoPipeline.from_pretrained(
            model_name,
            device=pipeline_device,
            num_workers=0,
            taskmodule_kwargs=dict(revision=revision),
            model_kwargs=dict(revision=revision),
        )
        gr.Info(
            f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
        )
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}")

    return model


def load_embedding_model(
    embedding_model_name: Optional[str] = None,
    # embedding_model_revision: Optional[str] = None,
    embedding_max_length: int = 512,
    embedding_batch_size: int = 16,
    device: str = "cpu",
) -> Optional[EmbeddingModel]:
    """Load a Huggingface embedding model for extracting text span embeddings.

    Args:
        embedding_model_name: The name of the embedding model to load. If None or empty,
            no model is loaded.
        embedding_max_length: The maximum input length for the embedding model.
        embedding_batch_size: The batch size to use when computing embeddings.
        device: The device to load the model onto.

    Returns:
        The loaded embedding model, or None if no model name was provided.
    """

    if embedding_model_name is not None and embedding_model_name.strip():
        try:
            embedding_model = HuggingfaceEmbeddingModel(
                embedding_model_name.strip(),
                # revision=embedding_model_revision,
                device=device,
                max_length=embedding_max_length,
                batch_size=embedding_batch_size,
            )
            gr.Info(f"Loaded embedding model: model_name={embedding_model_name}, device={device}")
        except Exception as e:
            raise gr.Error(f"Failed to load embedding model: {e}")
    else:
        embedding_model = None

    return embedding_model


def load_models(
    model_name: str,
    revision: Optional[str] = None,
    embedding_model_name: Optional[str] = None,
    # embedding_model_revision: Optional[str] = None,
    embedding_max_length: int = 512,
    embedding_batch_size: int = 16,
    device: str = "cpu",
) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
    """Load the argumentation model and, optionally, the embedding model.

    Args:
        model_name: The name of the argumentation model to load.
        revision: The revision of the argumentation model.
        embedding_model_name: The name of the embedding model. If None or empty, no
            embedding model is loaded.
        embedding_max_length: The maximum input length for the embedding model.
        embedding_batch_size: The batch size to use when computing embeddings.
        device: The device to load the models onto.

    Returns:
        A tuple of the argumentation pipeline and the embedding model (or None).
    """

    # free cached GPU memory before (re-)loading models
    torch.cuda.empty_cache()
    argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
    embedding_model = load_embedding_model(
        embedding_model_name=embedding_model_name,
        # embedding_model_revision=embedding_model_revision,
        embedding_max_length=embedding_max_length,
        embedding_batch_size=embedding_batch_size,
        device=device,
    )

    return argumentation_model, embedding_model
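

# A minimal usage sketch (not part of the Gradio app): it wires the helpers above together
# to annotate a single text. The argumentation model name below is a placeholder, not an
# actual checkpoint from this project. The helpers emit Gradio notifications, so they are
# normally called from within the app; running this standalone is only for quick checks.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # placeholder model identifier, for illustration only
    argumentation_model, embedding_model = load_models(
        model_name="some-org/some-argumentation-pipeline",
        embedding_model_name=None,  # skip embedding extraction in this sketch
        device="cpu",
    )

    document = create_document(
        text="Claim: local helpers are easy to test. Premise: they hold no UI state.",
        doc_id="example-doc",
    )
    annotate_document(document, argumentation_model, embedding_model=embedding_model)

    # predicted annotations end up in the predictions sub-layers of the document
    for span in document.labeled_spans.predictions:
        logger.info("span %s: %r", span.label, document.text[span.start : span.end])
    for relation in document.binary_relations.predictions:
        logger.info("relation %s: (%s) -> (%s)", relation.label, relation.head, relation.tail)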