File size: 1,112 Bytes
2d4811a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForTokenClassification  # type: ignore
from transformers import AutoTokenizer  # type: ignore


@st.experimental_singleton()
def get_model(model_name: str, labels=None):
    """Load a token-classification model with attention outputs enabled.

    The function is wrapped in ``st.experimental_singleton`` so Streamlit
    can reuse the loaded model instead of re-creating it on each call.

    Args:
        model_name: Identifier forwarded to
            ``AutoModelForTokenClassification.from_pretrained``.
        labels: Optional sized collection of tag names. When given, the
            model is configured with ``num_labels`` and with
            ``id2label`` / ``label2id`` mappings built from the position
            of each tag in ``labels``. When ``None``, no label
            configuration is passed.

    Returns:
        The object returned by
        ``AutoModelForTokenClassification.from_pretrained``.
    """
    # Label-related keyword arguments are only supplied when the caller
    # provides an explicit tag set.
    label_config = {}
    if labels is not None:
        label_config = {
            "num_labels": len(labels),
            "id2label": dict(enumerate(labels)),
            "label2id": {tag: idx for idx, tag in enumerate(labels)},
        }

    return AutoModelForTokenClassification.from_pretrained(
        model_name,
        output_attentions=True,
        **label_config,
    )  # type: ignore


@st.experimental_singleton()
def get_encoder(model_name: str, device: str = "cpu"):
    """Build a ``SentenceTransformer`` for ``model_name`` on ``device``.

    The function is wrapped in ``st.experimental_singleton`` so Streamlit
    can reuse the encoder instead of re-creating it on each call.

    Args:
        model_name: Identifier passed to the ``SentenceTransformer``
            constructor.
        device: Device string forwarded as ``device=``; defaults to
            ``"cpu"``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    encoder = SentenceTransformer(model_name, device=device)
    return encoder


@st.experimental_singleton()
def get_tokenizer(tokenizer_name: str):
    """Load the tokenizer identified by ``tokenizer_name``.

    The function is wrapped in ``st.experimental_singleton`` so Streamlit
    can reuse the tokenizer instead of re-creating it on each call.

    Args:
        tokenizer_name: Identifier forwarded to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The object returned by ``AutoTokenizer.from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer