|
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification |
|
import tensorflow as tf |
|
import gradio as gr |
|
|
|
|
|
# Identifier passed to from_pretrained(); presumably a Hugging Face Hub repo id
# ("owner/name") of the fine-tuned symptom -> diagnosis classifier.
model_name = "Zabihin/Symptom_to_Diagnosis"

# Tokenizer and TensorFlow sequence-classification model are created once, at
# import time, and shared by every call to predict().
# NOTE(review): from_pretrained may need to download files on first run —
# confirm the deployment has network access or a populated cache.
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
|
|
|
|
|
def clean_input(symptom_text):
    """Normalise raw symptom text before it is tokenized.

    Every character outside the 7-bit ASCII range is dropped, then the
    remaining text is lower-cased.

    Args:
        symptom_text: Free-text symptom description typed by the user.

    Returns:
        The ASCII-only, lower-case form of ``symptom_text``.
    """
    # str.isascii on a single character is True exactly when ord(ch) < 128.
    ascii_only = "".join(filter(str.isascii, symptom_text))
    return ascii_only.lower()
|
|
|
|
|
# Class index -> human-readable diagnosis shown to the user.
# NOTE(review): assumed to match the model's output order; confirm against
# model.config.id2label for the loaded checkpoint.
_DIAGNOSIS_LABELS = {
    0: "Allergy", 1: "Arthritis", 2: "Bronchial Asthma", 3: "Cervical Spondylosis",
    4: "Chicken Pox", 5: "Common Cold", 6: "Dengue", 7: "Diabetes", 8: "Drug Reaction",
    9: "Fungal Infection", 10: "Gastroesophageal Reflux Disease", 11: "Hypertension",
    12: "Impetigo", 13: "Jaundice", 14: "Malaria", 15: "Migraine", 16: "Peptic Ulcer Disease",
    17: "Pneumonia", 18: "Psoriasis", 19: "Typhoid", 20: "Urinary Tract Infection", 21: "Varicose Veins",
}


def _diagnose(cleaned_text):
    """Run the classifier on already-cleaned text and return the diagnosis label."""
    inputs = tokenizer(cleaned_text, return_tensors="tf", padding=True, truncation=True, max_length=512)
    logits = model(**inputs).logits
    # argmax over the last (class) axis; [0] takes the single input's result.
    # int() turns the numpy scalar into a plain Python int for the dict lookup.
    class_index = int(tf.argmax(logits, axis=-1).numpy()[0])
    return _DIAGNOSIS_LABELS.get(class_index, "Unknown diagnosis")


def predict(symptom_text, chat_history=None):
    """Handle one chat turn: classify the user's symptoms and extend the history.

    Args:
        symptom_text: Raw text typed by the user.
        chat_history: Existing chat as a list of ``(user_message, bot_message)``
            pairs (gr.Chatbot's tuple format), or None / empty for a new chat.
            The list passed in is not mutated.

    Returns:
        A ``(history, "")`` tuple: the updated list of pairs, and an empty
        string used to clear the input textbox.
    """
    # Copy instead of mutating: avoids a shared default list across calls
    # and leaves the caller's object untouched.
    history = list(chat_history) if chat_history else []
    user_text = symptom_text or ""

    # A blank submission is ignored rather than "diagnosed".
    if not user_text.strip():
        return history, ""

    try:
        cleaned = clean_input(user_text)
        if cleaned.strip():
            diagnosis = _diagnose(cleaned)
            reply = f"Predicted Diagnosis: {diagnosis}. Please consult a doctor for more accurate results."
        else:
            # Input consisted only of characters clean_input strips out.
            reply = "Please describe your symptoms so I can suggest a possible diagnosis."
    except Exception as e:  # UI boundary: report the failure in the chat instead of crashing
        reply = f"Error: {e}"

    # One (user, bot) pair per turn; show the text exactly as the user typed it.
    history.append((user_text, reply))
    return history, ""
|
|
|
|
|
with gr.Blocks() as interface:
    # Page heading: raw HTML rendered through the Markdown component.
    gr.Markdown(
        value=(
            "<h1 style='text-align: center; margin-top: 20px; "
            "margin-bottom: 20px; font-size: 36px;'>"
            "Medi Mind - Your AI Health Assistant</h1>"
        )
    )
    chatbot = gr.Chatbot()
    input_box = gr.Textbox(placeholder="Describe your symptoms here...", show_label=False)
    send_button = gr.Button(value="Send")

    # Pressing Enter in the textbox and clicking "Send" run the same handler:
    # predict(text, history) -> (updated history, cleared textbox).
    input_box.submit(fn=predict, inputs=[input_box, chatbot], outputs=[chatbot, input_box])
    send_button.click(fn=predict, inputs=[input_box, chatbot], outputs=[chatbot, input_box])
|
|
|
# Start the web UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link, exposing the
    # app beyond localhost — NOTE(review): confirm that is intended for a
    # health-related tool.
    interface.launch(share=True)
|
|