|
from pathlib import Path |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
import os |
|
from transformers import AutoTokenizer, AutoModel |
|
import requests |
|
|
|
|
|
# Hugging Face access token, supplied through the HF_TOKEN environment variable.
huggingface_token = os.getenv('HF_TOKEN')

# An empty HF_TOKEN is treated the same as an unset one: neither is usable.
if huggingface_token:
    # Also export the token under HUGGINGFACE_CO_API_TOKEN
    # (presumably for tooling that reads that name — confirm it is still needed).
    os.environ['HUGGINGFACE_CO_API_TOKEN'] = huggingface_token

    # Hosted Inference API endpoint for the same model.
    API_URL = "https://api-inference.huggingface.co/models/Tokymin/Mood_Anxiety_Disorder_Classify_Model"

    # The Inference API authenticates with a standard "Bearer <token>" header;
    # any other scheme is rejected.
    headers = {"Authorization": f"Bearer {huggingface_token}"}

else:
    import sys

    # sys.exit with a string writes it to stderr and exits with status 1,
    # so a missing token is reported as a failure rather than a success.
    sys.exit("error, no token")
|
|
|
|
|
|
|
|
|
|
|
# Hub repository id of the fine-tuned classifier. Loading uses the plain string:
# a pathlib.Path would render with "\" separators on Windows and no longer be a
# valid Hub repo id.
MODEL_ID = 'Tokymin/Mood_Anxiety_Disorder_Classify_Model'

# Kept for backward compatibility with any code that reads ``path``.
path: Path = Path(MODEL_ID)


@st.cache_resource
def _load_model_and_tokenizer(model_id, token):
    """Load the tokenizer and the 8-label sequence classifier for *model_id*.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without caching, both objects would be
    reloaded on each "Predict" click.

    Args:
        model_id: Hugging Face Hub repository id.
        token: Access token (may be None); passed to BOTH loads so a private
            or gated repository works for the model as well as the tokenizer.

    Returns:
        (tokenizer, model) with the model switched to eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=model_id,
        cache_dir='/home/user',
        token=token,
    )
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        num_labels=8,
        token=token,
    )
    # Inference only: disable training-time behaviour such as dropout.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model_and_tokenizer(MODEL_ID, huggingface_token)
|
|
|
|
|
def predict(text):
    """Classify *text* and return separate SAS and SDS probability vectors.

    The classifier's output logits are split as in the original code: the
    first 4 are the SAS classes and the remaining ones the SDS classes
    (4 of each given ``num_labels=8`` at load time).

    Args:
        text: A single string, or a list of strings for batched inference.

    Returns:
        ``(sas_probs, sds_probs)``: tensors of shape ``(4,)`` for a single
        input (batch of one), or ``(batch, 4)`` for several inputs.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    with torch.no_grad():
        logits = model(**inputs).logits

    # NOTE(review): assumes the outputs are two independent 4-way
    # classifications (SAS, then SDS) — confirm against the training setup.
    # Normalising each group on its own makes each vector a proper
    # distribution summing to 1; a single softmax over all 8 logits would make
    # the two scales compete and neither group would sum to 1.
    # Slicing the LAST axis keeps this correct for batched input, where the
    # old squeeze-then-[:4] sliced the batch axis instead of the labels.
    sas_probs = torch.softmax(logits[..., :4], dim=-1)
    sds_probs = torch.softmax(logits[..., 4:], dim=-1)

    # Drop the batch axis for a batch of one, matching the original shapes;
    # no-op when there are several inputs.
    return sas_probs.squeeze(0), sds_probs.squeeze(0)
|
|
|
|
|
|
|
st.title("Multi-label Classification App")

# The hint is a placeholder, not a default value, so it is never sent to the
# model as if the user had typed it.
user_input = st.text_area("Enter text here", placeholder="Type something...")

if st.button("Predict"):
    # Guard against classifying empty or whitespace-only input.
    if not user_input.strip():
        st.warning("Please enter some text to classify.")
    else:
        sas_probs, sds_probs = predict(user_input)

        st.write("SAS_Class probabilities:", sas_probs.numpy())

        st.write("SDS_Class probabilities:", sds_probs.numpy())
|
|