File size: 1,112 Bytes
2d4811a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForTokenClassification # type: ignore
from transformers import AutoTokenizer # type: ignore
@st.experimental_singleton()
def get_model(model_name: str, labels=None):
    """Load a token-classification model, cached as a Streamlit singleton.

    Args:
        model_name: Name or path passed to
            ``AutoModelForTokenClassification.from_pretrained``.
        labels: Optional sequence of tag names. When given, the model is
            loaded with ``num_labels=len(labels)`` plus ``id2label`` /
            ``label2id`` mappings built from the sequence order. When
            ``None``, no label-related arguments are passed.

    Returns:
        The loaded model, always requested with ``output_attentions=True``.

    Raises:
        ValueError: If ``labels`` contains duplicate entries. Duplicates would
            make ``label2id`` smaller than ``id2label`` / ``num_labels``,
            leaving the label configuration inconsistent.
    """
    # NOTE(review): st.experimental_singleton is deprecated in newer Streamlit
    # releases in favour of st.cache_resource; confirm against the pinned version.
    label_kwargs = {}
    if labels is not None:
        # label2id is keyed by tag, so repeated tags would silently overwrite
        # earlier indices; reject them up front instead.
        if len(set(labels)) != len(labels):
            raise ValueError(
                f"labels must not contain duplicates, got {list(labels)!r}"
            )
        label_kwargs = {
            "num_labels": len(labels),
            "id2label": dict(enumerate(labels)),
            "label2id": {tag: idx for idx, tag in enumerate(labels)},
        }
    return AutoModelForTokenClassification.from_pretrained(
        model_name,
        output_attentions=True,
        **label_kwargs,
    )  # type: ignore
@st.experimental_singleton()
def get_encoder(model_name: str, device: str = "cpu"):
    """Build a SentenceTransformer encoder, cached as a Streamlit singleton.

    Args:
        model_name: Model identifier forwarded positionally to
            ``SentenceTransformer``.
        device: Device string forwarded as ``device=``; defaults to ``"cpu"``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    encoder = SentenceTransformer(model_name, device=device)
    return encoder
@st.experimental_singleton()
def get_tokenizer(tokenizer_name: str):
    """Load a tokenizer via ``AutoTokenizer``, cached as a Streamlit singleton.

    Args:
        tokenizer_name: Name or path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer
|