crocus-medical / app.py
caesarCITREA's picture
Add application file
e3bd131
import streamlit as st
import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Load the tokenizer and model
tokenizer = BertTokenizer.from_pretrained('caesarCITREA/crocus-bert-medical-department-classification')
model = BertForSequenceClassification.from_pretrained('caesarCITREA/crocus-bert-medical-department-classification')
# Define the department names
departments = [
"Kadın Hastalıkları ve Doğum",
"Ortopedi ve Travmatoloji" ,
"Dermatoloji",
"Göğüs Hastalıkları ",
"Nöroloji",
"Onkoloji" ,
"Dahiliye (İç Hastalıkları)" ,
"Kardiyoloji",
"Psikiyatri" ,
"Pediatri" ,
"Nefroloji" ,
"Fiziksel Tıp ve Rehabilitasyon" ,
"Enfeksiyon Hastalıkları ve Klinik Mikrobiyoloji" ,
"Üroloji" ,
"Kulak Burun Boğaz (KBB)",
"Göz Hastalıkları"
]
# Function to predict the department
def predict_department(description):
# Tokenize input
inputs = tokenizer(description, return_tensors="pt", truncation=True, padding=True)
# Perform inference
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# Get the department with the highest score
predicted_class = torch.argmax(logits, dim=1).item()
# Return the department name
return departments[predicted_class]
# Streamlit app interface
st.title("Medical Department Classifier")
# Input text box for the user to describe the symptoms
description = st.text_area("Lütfen yaşadığınız tıbbi şikayetleri giriniz:")
# Button to classify the input
if st.button("Classify"):
if description:
department = predict_department(description)
st.write(f"Gitmeniz gereken tıbbi departman: **{department}**")
else:
st.write("Lütfen yaşadığınız durumu açıklanıyınız.")