import csv
import json
import tempfile
from typing import List

import gradio as gr
import torch
from faster_whisper import WhisperModel
from pydantic import AliasChoices, BaseModel, Field, ValidationError, field_validator
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Checkpoints for the three models used by the app.
numind_checkpoint = "numind/NuExtract-tiny"
llama_checkpoint = "Atereoyin/Llama3_finetuned_for_medical_entity_extraction"
whisper_checkpoint = "large-v3"

# 8-bit (bitsandbytes) quantization shared by both Hugging Face causal LMs.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

# Load every model once at import time; Whisper is placed on the GPU explicitly.
whisper_model = WhisperModel(whisper_checkpoint, device="cuda")
numind_model = AutoModelForCausalLM.from_pretrained(
    numind_checkpoint,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
numind_tokenizer = AutoTokenizer.from_pretrained(numind_checkpoint)
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_checkpoint,
    quantization_config=quantization_config,
    trust_remote_code=True,
)
llama_tokenizer = AutoTokenizer.from_pretrained(llama_checkpoint)


def transcribe_audio(audio_file_path):
    """Transcribe an audio file to text with the module-level Whisper model.

    Args:
        audio_file_path: Path to an audio file readable by faster-whisper.

    Returns:
        The concatenated text of all transcribed segments.

    Raises:
        gr.Error: If transcription fails. The error is raised (and shown in
            the UI) rather than returned, so an error message is never passed
            downstream as if it were the conversation transcript.
    """
    try:
        # Consume the segments inside the try: faster-whisper yields them
        # lazily, so decoding errors can surface during iteration.
        segments, _info = whisper_model.transcribe(audio_file_path, beam_size=5)
        return "".join(segment.text for segment in segments)
    except Exception as e:
        raise gr.Error(f"Transcription failed: {e}") from e


def predict_NuExtract(model, tokenizer, text, schema, example=None):
    """Extract the fields of a JSON template from ``text`` with NuExtract.

    Args:
        model: NuExtract causal LM.
        tokenizer: Tokenizer matching ``model``.
        text: Free text to extract from.
        schema: JSON string describing the template to fill.
        example: Optional iterable of JSON example strings; empty entries are
            skipped. ``None`` (the default) means no examples.

    Returns:
        The raw model text between ``<|output|>`` and ``<|end-output|>``
        (expected to be JSON; not parsed here), or ``""`` when the
        ``<|output|>`` marker is absent from the decoded sequence, e.g.
        because the prompt was truncated at 4000 tokens.
    """
    schema = json.dumps(json.loads(schema), indent=4)
    input_llm = "<|input|>\n### Template:\n" + schema + "\n"
    for ex in example or ():
        if ex:
            input_llm += "### Example:\n" + json.dumps(json.loads(ex), indent=4) + "\n"
    input_llm += "### Text:\n" + text + "\n<|output|>\n"

    # Send the inputs to whatever device the model was loaded on instead of
    # hard-coding "cuda".
    input_ids = tokenizer(
        input_llm, return_tensors="pt", truncation=True, max_length=4000
    ).to(model.device)
    output = tokenizer.decode(model.generate(**input_ids)[0], skip_special_tokens=True)

    # The decoded sequence echoes the prompt; the answer follows the first
    # "<|output|>" marker, which truncation can cut off.
    parts = output.split("<|output|>")
    if len(parts) < 2:
        return ""
    return parts[1].split("<|end-output|>")[0]
# Function for generating prompts for Llama
def prompt_format(text):
    """Build the instruction / input / response prompt for the Llama extractor.

    Args:
        text: Conversation transcript inserted as the prompt's input section.

    Returns:
        The full prompt string. It ends right after the "### Response:" header,
        so the model continues with its answer.
    """
    prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

    instruction = """Extract the following entities from the medical conversation:
* **Symptoms:** List all the symptoms the patient mentions.
* **Diagnosis:** List the doctor's diagnosis or potential diagnoses.
* **Medical History:** Summarize the patient's relevant medical history.
* **Action Plan:** List the recommended actions or treatment plan.

Provide the result in the following JSON format:
{
    "Symptoms": [...],
    "Diagnosis": [...],
    "Medical history": [...],
    "Action plan": [...]
}"""
    # Braces inside `instruction`/`text` are safe: they are format arguments,
    # not part of the template.
    return prompt.format(instruction, text, "")


class _MedicalRecord(BaseModel):
    """Schema for the medical entities extracted by Llama.

    Each field is a list of strings. ``validation_alias`` accepts the key
    spellings the model may emit (the prompt itself uses both "Medical History"
    and "Medical history", "Action Plan" and "Action plan").
    """

    Symptoms: List[str] = Field(default_factory=list)
    Diagnosis: List[str] = Field(default_factory=list)
    Medical_history: List[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices("Medical history", "Medical History", "History of Patient"),
    )
    Action_plan: List[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices("Action plan", "Action Plan", "Plan of Action"),
    )

    @field_validator("*", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Coerce a comma-separated string, ``None`` or list into a list of strings."""
        if v is None:
            return []
        if isinstance(v, str):
            return [item.strip() for item in v.split(",") if item.strip()]
        if isinstance(v, (list, tuple)):
            # JSON answers may contain non-string scalars; stringify them
            # instead of failing validation.
            return [str(item).strip() for item in v if item is not None and str(item).strip()]
        return v


# Pydantic validator to validate Llama's response
def validate_medical_record(response):
    """Normalize a parsed Llama response into the four medical-entity lists.

    Args:
        response: Mapping of field name -> value parsed from the model output.
            Values may be comma-separated strings or lists.

    Returns:
        A dict with keys "Symptoms", "Diagnosis", "Medical_history" and
        "Action_plan", each a list of strings. If validation fails, the
        unmodified ``response`` is returned instead.
    """
    try:
        return _MedicalRecord.model_validate(response).model_dump()
    except ValidationError:
        return response


def _parse_llama_response(response):
    """Turn the raw Llama completion into a dict of field name -> raw value."""
    # The prompt asks for JSON, so try the outermost {...} span first.
    start, end = response.find("{"), response.rfind("}")
    if start != -1 and end > start:
        try:
            parsed = json.loads(response[start:end + 1])
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return {str(key).strip(): value for key, value in parsed.items()}

    # Fallback: "Key: value" lines, with markdown bold/bullet markers removed.
    fields = {}
    for line in response.splitlines():
        key, sep, value = line.replace("**", "").partition(": ")
        if sep:
            fields[key.strip().lstrip("*- ").strip('"')] = value.strip()
    return fields


# Function to predict medical entities using Llama
def predict_Llama(model, tokenizer, text, max_new_tokens=128):
    """Extract symptoms, diagnosis, history and action plan from a transcript.

    Args:
        model: The fine-tuned Llama causal LM.
        tokenizer: Tokenizer matching ``model``.
        text: Conversation transcript.
        max_new_tokens: Generation budget; 128 keeps the previous behavior but
            may cut off long JSON answers.

    Returns:
        The result of ``validate_medical_record`` on the parsed answer, or
        ``{}`` if generation or parsing raises.
    """
    inputs = tokenizer(prompt_format(text), return_tensors="pt", truncation=True).to(model.device)
    try:
        # NOTE(review): `temperature` only takes effect when sampling is enabled
        # (do_sample=True here or in the model's generation config); otherwise
        # decoding is greedy. Confirm which is intended.
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens, temperature=0.2, use_cache=True
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # Defensive: if a response header was generated, keep what follows it.
        response = response.split("### Response:", 1)[-1].strip()
        return validate_medical_record(_parse_llama_response(response))
    except Exception as e:
        print(f"Error during Llama prediction: {str(e)}")
        return {}


def _join_entities(value):
    """Render an entity value (list or string) as one comma-separated string."""
    if isinstance(value, (list, tuple)):
        return ", ".join(str(item) for item in value)
    return "" if value is None else str(value)


# Control function that coordinates the other functions to map entities to form fields
def process_audio(audio):
    """Transcribe the audio, then extract person and medical entities.

    Args:
        audio: Path to an audio file (what ``gr.Audio(type="filepath")``
            delivers) or raw audio bytes.

    Returns:
        A 7-tuple for the form textboxes: name, age, gender, symptoms,
        diagnosis, medical history, action plan. List-valued medical entities
        are joined with ", ".

    Raises:
        gr.Error: If no audio was given, or NuExtract's output is not a JSON
            object. The message is shown in the UI; returning a single string
            would not match the seven output components.
    """
    import os

    if audio is None:
        raise gr.Error("Please upload or record an audio file.")

    if isinstance(audio, str):
        # A file path can be transcribed directly; no copy is needed.
        transcription = transcribe_audio(audio)
    else:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio.write(audio)
            audio_path = temp_audio.name
        try:
            transcription = transcribe_audio(audio_path)
        finally:
            # delete=False was needed so Whisper can reopen the file; clean up here.
            os.remove(audio_path)

    person_schema = """{"Name": "","Age": "","Gender": ""}"""
    person_entities_raw = predict_NuExtract(numind_model, numind_tokenizer, transcription, person_schema)
    try:
        person_entities = json.loads(person_entities_raw)
    except json.JSONDecodeError as e:
        raise gr.Error(f"Error in NuExtract response: {str(e)}") from e
    if not isinstance(person_entities, dict):
        raise gr.Error("Error in NuExtract response: expected a JSON object")

    medical_entities = predict_Llama(llama_model, llama_tokenizer, transcription)

    return (
        person_entities.get("Name", ""),
        person_entities.get("Age", ""),
        person_entities.get("Gender", ""),
        _join_entities(medical_entities.get("Symptoms", [])),
        _join_entities(medical_entities.get("Diagnosis", [])),
        _join_entities(medical_entities.get("Medical_history", [])),
        _join_entities(medical_entities.get("Action_plan", [])),
    )


# Function that lets users download the form data
def download_csv(name, age, gender, symptoms, diagnosis, medical_history, action_plan):
    """Write the form fields to a new temporary CSV file and return its path.

    The file is created with ``delete=False`` so it still exists when Gradio
    serves it for download. It is written as UTF-8 so non-ASCII names do not
    fail on platforms with a narrower default encoding.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".csv", newline="", encoding="utf-8"
    ) as file:
        writer = csv.writer(file)
        writer.writerow(["Name", "Age", "Gender", "Symptoms", "Diagnosis", "Medical History", "Plan of Action"])
        writer.writerow([name, age, gender, symptoms, diagnosis, medical_history, action_plan])
    return file.name
# Gradio interface: a web form where users upload or record audio and the
# extracted entities fill in the medical diagnostic form.
demo = gr.Interface(
    fn=process_audio,
    inputs=[
        gr.Audio(type="filepath"),
    ],
    outputs=[
        gr.Textbox(label="Name"),
        gr.Textbox(label="Age"),
        gr.Textbox(label="Gender"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diagnosis"),
        gr.Textbox(label="Medical History"),
        gr.Textbox(label="Plan of Action"),
    ],
    title="Medical Diagnostic Form Assistant",
    description="Upload an audio file or record audio to generate a medical diagnostic form.",
)

with demo:
    download_button = gr.Button("Download CSV")
    csv_output = gr.File(label="Download CSV")
    # The seven output textboxes are passed positionally, in the same order
    # as download_csv's parameters, so no wrapper lambda is needed.
    download_button.click(
        fn=download_csv,
        inputs=demo.output_components,
        outputs=csv_output,
    )

if __name__ == "__main__":
    demo.launch()