atereoyinn
added cudnn
6787f2a
raw
history blame
No virus
7.84 kB
import csv
import json
import os
import tempfile
from typing import List

import gradio as gr
import torch
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, AliasChoices, field_validator, ValidationError
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Hub checkpoints for the three models used by the app.
numind_checkpoint = "numind/NuExtract-tiny"
llama_checkpoint = "Atereoyin/Llama3_finetuned_for_medical_entity_extraction"
whisper_checkpoint = "large-v3"

# 8-bit bitsandbytes quantization, shared by both causal-LM checkpoints below.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)


def _load_causal_lm(checkpoint, **model_kwargs):
    """Load an 8-bit causal LM and its tokenizer from ``checkpoint``.

    The model is loaded first, then the tokenizer. Extra keyword arguments are
    forwarded to ``AutoModelForCausalLM.from_pretrained``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        quantization_config=quantization_config,
        trust_remote_code=True,
        **model_kwargs,
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


# Whisper is explicitly placed on the GPU.
whisper_model = WhisperModel(whisper_checkpoint, device="cuda")
numind_model, numind_tokenizer = _load_causal_lm(numind_checkpoint, torch_dtype=torch.float16)
llama_model, llama_tokenizer = _load_causal_lm(llama_checkpoint)
# Function to transcribe audio
def transcribe_audio(audio_file_path):
    """Transcribe the audio file at ``audio_file_path`` with the global Whisper model.

    Returns the text of all segments concatenated with no separator. If any
    exception is raised while transcribing, its message is returned instead.

    NOTE(review): callers cannot tell an error message from a real transcript,
    so a failure is passed on to the extraction models as if it were speech.
    """
    try:
        # beam_size=5 trades some speed for decoding quality.
        segments, _info = whisper_model.transcribe(audio_file_path, beam_size=5)
        # Segments are consumed here, inside the try, so decoding errors are caught too.
        return "".join(seg.text for seg in segments)
    except Exception as err:
        return str(err)
# Function for schema-driven entity extraction with NuExtract (used here for person details)
def predict_NuExtract(model, tokenizer, text, schema, example=("", "", "")):
    """Extract the fields described by a JSON ``schema`` from ``text`` with NuExtract.

    Args:
        model: causal LM used for generation (NuExtract in this app).
        tokenizer: tokenizer matching ``model``.
        text: free text to extract from.
        schema: JSON string template; it is re-serialized with 4-space indentation.
        example: iterable of JSON example strings; empty entries are skipped.

    Returns:
        The generated text between the first ``<|output|>`` marker and the
        following ``<|end-output|>`` marker. It is expected to be JSON but is
        not parsed here. Returns ``""`` if the decoded output contains no
        ``<|output|>`` marker, so callers' JSON parsing fails cleanly.

    Raises:
        json.JSONDecodeError: if ``schema`` or a non-empty example is not valid JSON.
    """
    schema = json.dumps(json.loads(schema), indent=4)
    parts = ["<|input|>\n### Template:\n" + schema + "\n"]
    for sample in example:
        if sample:
            parts.append("### Example:\n" + json.dumps(json.loads(sample), indent=4) + "\n")
    parts.append("### Text:\n" + text + "\n<|output|>\n")
    input_llm = "".join(parts)

    # Send inputs to wherever the model lives instead of assuming "cuda".
    input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=4000).to(model.device)
    output = tokenizer.decode(model.generate(**input_ids)[0], skip_special_tokens=True)

    # The decoded text echoes the prompt, so the answer follows the first <|output|>.
    segments = output.split("<|output|>")
    if len(segments) < 2:
        return ""
    return segments[1].split("<|end-output|>")[0]
# Function for generating prompts for Llama
# Prompt wrapper with three positional slots: instruction, input and response.
_LLAMA_PROMPT_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""

# Task description inserted into the instruction slot. It is passed as a format
# argument, so its literal braces are not interpreted as placeholders.
_LLAMA_INSTRUCTION = """Extract the following entities from the medical conversation:
* **Symptoms:** List all the symptoms the patient mentions.
* **Diagnosis:** List the doctor's diagnosis or potential diagnoses.
* **Medical History:** Summarize the patient's relevant medical history.
* **Action Plan:** List the recommended actions or treatment plan.
Provide the result in the following JSON format:
{
"Symptoms": [...],
"Diagnosis": [...],
"Medical history": [...],
"Action plan": [...]
}"""


def prompt_format(text):
    """Build the Llama extraction prompt for the conversation ``text``.

    The response slot is left empty, so the prompt ends with
    ``"### Response:\\n"`` and generation continues from there.
    """
    return _LLAMA_PROMPT_TEMPLATE.format(_LLAMA_INSTRUCTION, text, "")
#Pydantic Validator to validate Llama's response
class MedicalRecord(BaseModel):
    """Normalized medical entities: every field is a list of strings.

    ``Medical_history`` and ``Action_plan`` are read only from the alternative
    keys listed in their aliases (not from the field names themselves).
    """

    Symptoms: List[str] = Field(default_factory=list)
    Diagnosis: List[str] = Field(default_factory=list)
    Medical_history: List[str] = Field(
        default_factory=list,
        # "Medical History" is the spelling used in the prompt's bullet list.
        validation_alias=AliasChoices('Medical history', 'History of Patient', 'Medical History'),
    )
    Action_plan: List[str] = Field(
        default_factory=list,
        # "Action Plan" is the spelling used in the prompt's bullet list.
        validation_alias=AliasChoices('Action plan', 'Plan of Action', 'Action Plan'),
    )

    @field_validator('*', mode='before')
    @classmethod
    def ensure_list(cls, v):
        """Coerce raw values to lists before type validation.

        ``None`` becomes ``[]``. A string is split on commas into stripped,
        non-empty items. Anything else is passed through unchanged.
        """
        if v is None:
            return []
        if isinstance(v, str):
            return [item.strip() for item in v.split(',') if item.strip()]
        return v


def validate_medical_record(response):
    """Normalize a raw ``{key: value}`` dict of Llama output into a MedicalRecord dict.

    Args:
        response: mapping of entity names to values (strings or lists).

    Returns:
        A dict with keys ``Symptoms``, ``Diagnosis``, ``Medical_history`` and
        ``Action_plan``, each a list of strings. Unrecognized keys are ignored.
        If validation fails, ``response`` is returned unchanged.
    """
    # The model class is built once at import time rather than on every call.
    try:
        return MedicalRecord.model_validate(response).model_dump()
    except ValidationError:
        return response
# Function to predict medical entities using Llama
def _normalize_entity_value(value):
    """Convert a JSON value into a str or list[str] accepted by the record validator."""
    if value is None:
        return []
    if isinstance(value, list):
        return [str(item) for item in value if item is not None]
    return str(value)


def _parse_line_value(value):
    """Return a JSON list value (e.g. '["a", "b"],') as a list; otherwise return ``value`` unchanged."""
    candidate = value.rstrip(",").strip()
    if candidate.startswith("["):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            return value
        if isinstance(parsed, list):
            return _normalize_entity_value(parsed)
    return value


def _parse_llama_response(response):
    """Turn the generated response text into a flat ``{key: value}`` dict.

    The prompt asks for a JSON object, so the outermost ``{...}`` span is tried
    first. If it is missing or invalid (for example, cut off by the token
    limit), ``Key: value`` lines are parsed instead. In that fallback, quotes,
    asterisks and dashes are stripped from the keys.
    """
    start, end = response.find("{"), response.rfind("}")
    if start != -1 and end > start:
        try:
            parsed = json.loads(response[start:end + 1])
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return {str(key).strip(): _normalize_entity_value(val) for key, val in parsed.items()}

    entities = {}
    for line in response.splitlines():
        if ': ' not in line:
            continue
        key, value = line.split(': ', 1)
        entities[key.strip(' \t"*-')] = _parse_line_value(value.strip())
    return entities


def predict_Llama(model, tokenizer, text, max_new_tokens=128):
    """Extract medical entities from a conversation ``text`` with the fine-tuned Llama.

    Args:
        model: causal LM used for generation.
        tokenizer: tokenizer matching ``model``.
        text: conversation transcript inserted into the prompt.
        max_new_tokens: generation budget (default 128, as before).

    Returns:
        The result of ``validate_medical_record`` on the parsed response, or
        ``{}`` if generation, decoding or parsing raises (the error is printed).
    """
    inputs = tokenizer(prompt_format(text), return_tensors="pt", truncation=True).to(model.device)
    try:
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, temperature=0.2, use_cache=True)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded text echoes the prompt; the answer follows "### Response:".
        response = decoded.split("### Response:", 1)[-1].strip()
        return validate_medical_record(_parse_llama_response(response))
    except Exception as e:
        # Best-effort boundary: an empty dict leaves the medical fields blank in the UI.
        print(f"Error during Llama prediction: {str(e)}")
        return {}
# Control function that coordinates the other functions to map extracted entities to form fields
def _entities_to_text(value):
    """Render an extracted entity value for a textbox: join lists with ", ", keep strings, None -> ""."""
    if value is None:
        return ""
    if isinstance(value, str):
        # Joining a str would interleave ", " between its characters.
        return value
    return ", ".join(str(item) for item in value)


def process_audio(audio):
    """Transcribe ``audio`` and map the extracted entities onto the seven form fields.

    Args:
        audio: path to an audio file (what ``gr.Audio(type="filepath")``
            supplies), raw audio bytes, or ``None`` when nothing was submitted.

    Returns:
        A 7-tuple: name, age, gender, symptoms, diagnosis, medical history and
        plan of action, one value per output textbox. All fields are empty
        when ``audio`` is ``None``. If NuExtract's reply is not valid JSON, the
        first field holds the error message and the others are empty.
    """
    field_count = 7
    if audio is None:
        return ("",) * field_count

    if isinstance(audio, (str, os.PathLike)):
        # The audio is already on disk; transcribe it in place instead of copying it.
        transcription = transcribe_audio(os.fspath(audio))
    else:
        # Raw bytes: write them to a temp file and always remove it afterwards.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio.write(audio)
            audio_path = temp_audio.name
        try:
            transcription = transcribe_audio(audio_path)
        finally:
            os.remove(audio_path)

    person_schema = """{"Name": "","Age": "","Gender": ""}"""
    person_entities_raw = predict_NuExtract(numind_model, numind_tokenizer, transcription, person_schema)
    try:
        person_entities = json.loads(person_entities_raw)
    except json.JSONDecodeError as e:
        # Gradio needs one value per output component, so pad the error to 7 fields.
        return (f"Error in NuExtract response: {str(e)}",) + ("",) * (field_count - 1)
    if not isinstance(person_entities, dict):
        person_entities = {}

    medical_entities = predict_Llama(llama_model, llama_tokenizer, transcription)
    return (
        person_entities.get("Name", ""),
        person_entities.get("Age", ""),
        person_entities.get("Gender", ""),
        _entities_to_text(medical_entities.get("Symptoms")),
        _entities_to_text(medical_entities.get("Diagnosis")),
        _entities_to_text(medical_entities.get("Medical_history")),
        _entities_to_text(medical_entities.get("Action_plan")),
    )
#Function that allows users to download information
def download_csv(name, age, gender, symptoms, diagnosis, medical_history, action_plan):
    """Write the form fields to a new CSV file and return the file's path.

    The file contains a header row and one data row and is encoded as UTF-8.
    It is created with ``delete=False``, so the returned path remains valid
    after the handle is closed. The caller is responsible for removing it.
    """
    # A single handle is opened and closed here. The original leaked the
    # NamedTemporaryFile handle and reopened the file by name, which fails on Windows.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["Name", "Age", "Gender", "Symptoms", "Diagnosis", "Medical History", "Plan of Action"])
        writer.writerow([name, age, gender, symptoms, diagnosis, medical_history, action_plan])
    return csv_file.name
# Gradio interface to create a web-based form for users to input audio and fill the medical diagnostic form
# Labels of the seven output textboxes, in the order process_audio returns them.
_FORM_FIELD_LABELS = (
    "Name",
    "Age",
    "Gender",
    "Symptoms",
    "Diagnosis",
    "Medical History",
    "Plan of Action",
)

demo = gr.Interface(
    fn=process_audio,
    inputs=[gr.Audio(type="filepath")],
    outputs=[gr.Textbox(label=label) for label in _FORM_FIELD_LABELS],
    title="Medical Diagnostic Form Assistant",
    description="Upload an audio file or record audio to generate a medical diagnostic form.",
)

# Add a CSV export button to the Interface layout. It reads the seven output
# textboxes and sends the path returned by download_csv to a File component.
with demo:
    download_button = gr.Button("Download CSV")
    download_button.click(
        fn=lambda name, age, gender, symptoms, diagnosis, medical_history, action_plan: download_csv(
            name, age, gender, symptoms, diagnosis, medical_history, action_plan
        ),
        inputs=demo.output_components,
        outputs=gr.File(label="Download CSV"),
    )

demo.launch()