Tokymin's picture
更换path
895be0d
raw
history blame
2.19 kB
from pathlib import Path
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
from transformers import AutoTokenizer, AutoModel
import requests
# Read the Hugging Face access token from the HF_TOKEN environment variable.
# It is passed to the Hub download calls and used to build the Inference API
# request header below.
huggingface_token = os.getenv('HF_TOKEN')
# An empty string is treated the same as an unset variable: it cannot
# authenticate anything.
if huggingface_token:
    # Mirror the token under a second variable name.
    # NOTE(review): confirm something still reads this name; the Hub libraries
    # read HF_TOKEN themselves.
    os.environ['HUGGINGFACE_CO_API_TOKEN'] = huggingface_token
    API_URL = "https://api-inference.huggingface.co/models/Tokymin/Mood_Anxiety_Disorder_Classify_Model"
    # The Inference API expects a standard bearer token; the previous
    # "Tokymin <token>" scheme is not a valid Authorization header.
    headers = {"Authorization": f"Bearer {huggingface_token}"}
else:
    print("error, no token")
    # Show the problem in the app and halt this script run. The old exit(0)
    # reported success and only wrote to the server log.
    st.error("error, no token")
    st.stop()
    # Fallback for runs outside `streamlit run`, where st.stop() may not halt
    # execution: exit with a failure status instead of continuing without a token.
    raise SystemExit(1)
# def query(payload):
# response = requests.post(API_URL, headers=headers, json=payload)
# return response.json()
# data = query("Can you please let us know more details about your ")
# Hub repository id of the fine-tuned classifier. It is a repo id, not a local
# filesystem path, so it is kept as a plain string: wrapping it in a Path
# would rewrite the separator as a backslash on Windows.
MODEL_ID = 'Tokymin/Mood_Anxiety_Disorder_Classify_Model'
# Kept for backward compatibility with any code that still reads `path`.
path: Path = Path(MODEL_ID)


@st.cache_resource
def _load_model_and_tokenizer(model_id, token):
    """Load the tokenizer and the 8-label sequence-classification model.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects so the weights are not
    reloaded on each "Predict" click.

    Args:
        model_id: Hugging Face Hub repository id.
        token: Hub access token, forwarded to both downloads.

    Returns:
        (tokenizer, model), with the model already switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(model_id, cache_dir='/home/user', token=token)
    # Use the same token and cache_dir as the tokenizer; previously the model
    # download was unauthenticated, which fails for a private repository.
    mdl = AutoModelForSequenceClassification.from_pretrained(
        model_id, num_labels=8, cache_dir='/home/user', token=token
    )
    mdl.eval()  # inference mode (disables dropout)
    return tok, mdl


tokenizer, model = _load_model_and_tokenizer(MODEL_ID, huggingface_token)
def predict(text):
    """Classify *text* and return the SAS_Class and SDS_Class probabilities.

    The model emits 8 logits per input. The first 4 are treated as the
    SAS_Class head and the last 4 as the SDS_Class head.
    NOTE(review): this layout is carried over from the original code; confirm
    it against the training setup.

    Args:
        text: A string, or a list of strings, to classify. Inputs are
            truncated to 512 tokens.

    Returns:
        (sas_probs, sds_probs): each has shape (4,) for a single input, or
        (batch, 4) for a batch. Each row sums to 1.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Normalize each head separately so each is a proper 4-way distribution.
    # A single softmax over all 8 logits made the heads compete, and neither
    # half summed to 1. Slicing the last axis keeps batched input intact; the
    # old squeeze()+[:4] sliced rows of the batch instead of logits.
    sas_probs = torch.softmax(logits[..., :4], dim=-1)
    sds_probs = torch.softmax(logits[..., 4:], dim=-1)
    # Drop the batch axis for a single input so the (4,) shape is unchanged.
    return sas_probs.squeeze(0), sds_probs.squeeze(0)
# Build the Streamlit app: a text box plus a button that runs the classifier.
st.title("Multi-label Classification App")
# "Type something..." is a hint, not input. As the default value it was
# submitted verbatim if the user pressed Predict without typing.
user_input = st.text_area("Enter text here", placeholder="Type something...")
if st.button("Predict"):
    if not user_input.strip():
        # Nothing meaningful to classify; don't run the model on blank text.
        st.warning("Please enter some text to classify.")
    else:
        # Show the prediction results.
        sas_probs, sds_probs = predict(user_input)
        st.write("SAS_Class probabilities:", sas_probs.numpy())
        st.write("SDS_Class probabilities:", sds_probs.numpy())