# Medical-Interviewer / settings.py
# Source revision: ruslanmv, "First commit" (e8173b9)
import traceback
from datetime import datetime
from pathlib import Path
import os
import random
import string
import tempfile
import re
import io
import PyPDF2
import docx
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.enums import TA_JUSTIFY
from ai_config import n_of_questions, load_model, openai_api_key, convert_text_to_speech
from knowledge_retrieval import setup_knowledge_retrieval, generate_report
# Initialize settings
# Call the imported ``n_of_questions`` and rebind the module-level name to its
# result: from here on it holds the configured number of interview questions.
n_of_questions = n_of_questions()
current_datetime = datetime.now()
human_readable_datetime = current_datetime.strftime("%B %d, %Y at %H:%M")
current_date = current_datetime.strftime("%Y-%m-%d")

# Initialize the model and retrieval chains.
# Pre-bind every name the try-block assigns so that, if initialization fails
# (fallback mode), later references such as ``llm`` in the report generators
# see ``None`` instead of raising ``NameError``.
llm = None
interview_retrieval_chain = None
report_retrieval_chain = None
combined_retriever = None
try:
    llm = load_model(openai_api_key)
    interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(llm)
    knowledge_base_connected = True
    print("Successfully connected to the knowledge base.")
except Exception as e:
    # Deliberate best-effort: the app keeps running in a reduced mode.
    print(f"Error initializing the model or retrieval chain: {str(e)}")
    knowledge_base_connected = False
    print("Falling back to basic mode without knowledge base.")

# Mutable interview state, updated by respond() / reset_interview().
question_count = 0
interview_history = []
last_audio_path = None  # Variable to store the path of the last audio file
initial_audio_path = None  # Variable to store the path of the initial audio file
language = None  # Preferred language, captured from the first patient answer
def generate_random_string(length=5):
    """Return a string of ``length`` random ASCII letters and digits.

    Used in this module as a suffix that keeps generated file names distinct.
    Not suitable for security-sensitive tokens (it uses ``random``).
    """
    alphabet = string.ascii_letters + string.digits
    picks = random.choices(alphabet, k=length)
    return "".join(picks)
def _extract_message_text(message):
    """Return the text of the latest user message.

    When *message* is a list, its last element is used; if that element is
    itself a list (presumably a ``[user, bot]`` pair), its first item is taken.
    An empty list yields ``""`` instead of raising ``IndexError``. Any
    non-list value is returned unchanged.
    """
    if not isinstance(message, list):
        return message
    if not message:
        return ""
    last = message[-1]
    return last[0] if isinstance(last, list) else last


def _save_question_audio(text, voice, file_name):
    """Synthesize *text* into ``file_name`` next to this module and rotate audio.

    After writing the new file, the file recorded in the global
    ``last_audio_path`` is deleted so old question audio does not accumulate,
    and ``last_audio_path`` is updated to the new file. Returns the new path.
    """
    global last_audio_path
    speech_file_path = Path(__file__).parent / file_name
    convert_text_to_speech(text, speech_file_path, voice)
    print(f"Question {question_count} saved as audio at {speech_file_path}")
    # Never delete the file we just wrote, even if a name were reused.
    if (last_audio_path and Path(last_audio_path) != speech_file_path
            and os.path.exists(last_audio_path)):
        os.remove(last_audio_path)
    last_audio_path = speech_file_path
    return speech_file_path


def respond(message, history, voice, selected_interviewer):
    """Record the patient's answer and produce the next interviewer turn.

    The first answer of an interview is taken as the preferred language. With
    the knowledge base connected, the next question comes from the interview
    retrieval chain until ``n_of_questions`` is reached, after which the final
    report is generated instead. Without it, a generic follow-up prompt is used.

    Args:
        message: The user's message (a string, or a list whose last entry is used).
        history: Chat history as a list of ``[user, bot]`` entries; the bot slot
            of the last entry is filled with the new interviewer text.
        voice: Voice identifier forwarded to ``convert_text_to_speech``.
        selected_interviewer: Forwarded to ``setup_knowledge_retrieval`` on the
            first turn.

    Returns:
        ``(history, audio_path)`` where ``audio_path`` is the string path of the
        synthesized mp3, or ``None`` when no audio was produced or an error
        occurred (errors are printed, not raised).
    """
    global question_count, interview_history, combined_retriever, last_audio_path, initial_audio_path, language, interview_retrieval_chain, report_retrieval_chain
    if not isinstance(history, list):
        history = []
    if not history or not history[-1]:
        history.append(["", ""])
    message = _extract_message_text(message)
    question_count += 1
    interview_history.append(f"Q{question_count}: {message}")
    history_str = "\n".join(interview_history)
    print("Starting interview", question_count)
    try:
        if question_count == 1:
            # The first answer is the patient's preferred language, in both
            # knowledge-base and fallback mode (fallback used to say "in None").
            language = str(message).strip().lower()
            if knowledge_base_connected:
                # Rebuild the chains for the chosen language and interviewer.
                (interview_retrieval_chain, report_retrieval_chain,
                 combined_retriever) = setup_knowledge_retrieval(
                    llm, language, selected_interviewer)

        speech_file_path = None
        if knowledge_base_connected:
            if question_count < n_of_questions:
                result = interview_retrieval_chain.invoke({
                    "input": f"Based on the patient's statement: '{message}', what should be the next question?",
                    "history": history_str,
                    "question_number": question_count + 1,
                    "language": language
                })
                question = result.get("answer", f"Can you tell me more about that? (in {language})")
            else:
                # Interview finished: this turn carries the final report.
                question = generate_report(report_retrieval_chain, interview_history, language)
            if question:
                # NOTE(review): any non-empty text is voiced, including the
                # final report -- same as before; confirm this is intended.
                speech_file_path = _save_question_audio(
                    question, voice,
                    f"question_{question_count}_{generate_random_string()}.mp3")
        else:
            # Fallback mode without knowledge base: generic follow-up prompt.
            question = f"Can you elaborate on that? (in {language})"
            if question_count < n_of_questions:
                speech_file_path = _save_question_audio(
                    question, voice, f"question_{question_count}.mp3")

        history[-1][1] = f"{question}"

        # The initial question audio is only needed until the first response.
        if initial_audio_path:
            if os.path.exists(initial_audio_path):
                os.remove(initial_audio_path)
            initial_audio_path = None

        # Older question audio was already removed by _save_question_audio.
        return history, str(speech_file_path) if speech_file_path else None
    except Exception as e:
        print(f"Error in retrieval chain: {str(e)}")
        print(traceback.format_exc())
        return history, None
def reset_interview():
    """Reset the interview state.

    Zeroes the question counter and clears the transcript. Any question audio
    still on disk is deleted, both the last generated question and the
    initial question audio (previously the latter was only forgotten, which
    leaked the file). Both audio paths are then reset to ``None``.
    """
    global question_count, interview_history, last_audio_path, initial_audio_path
    question_count = 0
    interview_history = []
    for stale_path in (last_audio_path, initial_audio_path):
        if stale_path and os.path.exists(stale_path):
            os.remove(stale_path)
    last_audio_path = None
    initial_audio_path = None
def _document_kind(name):
    """Return ".txt", ".pdf" or ".docx" if *name* ends with it (case-insensitive), else None."""
    lowered = str(name).lower()
    for ext in (".txt", ".pdf", ".docx"):
        if lowered.endswith(ext):
            return ext
    return None


def _upload_bytes(file):
    """Return the raw payload of an upload object.

    Uses ``file.content`` when the object provides it; otherwise treats
    ``file.name`` as a filesystem path and reads it in binary mode.
    """
    content = getattr(file, "content", None)
    if content is not None:
        return content
    with open(file.name, "rb") as handle:
        return handle.read()


def _extract_document_text(kind, data):
    """Extract plain text from PDF (``kind == ".pdf"``) or DOCX bytes."""
    stream = io.BytesIO(data)
    if kind == ".pdf":
        reader = PyPDF2.PdfReader(stream)
        # Guard against pages whose extract_text() yields nothing (e.g. None).
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    document = docx.Document(stream)
    return "\n".join(paragraph.text for paragraph in document.paragraphs)


def read_file(file):
    """Return the text of an uploaded file, or a status message.

    Args:
        file: ``None``; a filesystem path (``str`` or ``os.PathLike``); or an
            upload object with a ``name`` and optionally ``content`` bytes.

    Returns:
        The extracted text. ``.pdf`` and ``.docx`` are parsed; ``.txt`` upload
        bytes are decoded as UTF-8 (invalid bytes replaced). Any other path is
        read as UTF-8 text. Extensions match case-insensitively. On failure,
        one of the sentinel strings ``"No file uploaded"``,
        ``"Unsupported file format"`` or ``"Unable to read file"``.
        I/O and parser errors propagate to the caller.
    """
    if file is None:
        return "No file uploaded"
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
        kind = _document_kind(path)
        if kind in (".pdf", ".docx"):
            # Binary formats must not be decoded as text.
            with open(path, "rb") as handle:
                return _extract_document_text(kind, handle.read())
        # Any other path (including .txt) is read as UTF-8 text, as before.
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    if hasattr(file, 'name'):  # Check if it's a file-like object
        kind = _document_kind(file.name)
        if kind is None:
            return "Unsupported file format"
        data = _upload_bytes(file)
        if kind == ".txt":
            if isinstance(data, (bytes, bytearray)):
                return bytes(data).decode("utf-8", errors="replace")
            return data
        return _extract_document_text(kind, data)
    return "Unable to read file"
def generate_report_from_file(file, language):
    """Generate a clinical report (text and PDF) from an uploaded document.

    Args:
        file: Any input accepted by ``read_file`` (path or upload object).
        language: Desired report language; empty or ``None`` means ``"english"``.

    Returns:
        ``(report_text, pdf_path)`` on success. If the file cannot be read, or
        any error occurs, returns ``(message, None)``. This is the same
        two-item shape on every path, so callers can always unpack it
        (previously the unreadable-file case returned a bare string).
    """
    # Status strings read_file returns instead of document text.
    unreadable = ("No file uploaded", "Unsupported file format", "Unable to read file")
    # Cap on characters forwarded to the report chain.
    max_chars = 100000
    try:
        file_content = read_file(file)
        if file_content in unreadable:
            return file_content, None
        file_content = file_content[:max_chars]
        report_language = language.strip().lower() if language else "english"
        print(f"Generating report in language: {report_language}")  # For debugging
        # Reinitialize the report chain with the new language
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)
        result = report_retrieval_chain.invoke({
            "input": "Please provide a clinical report based on the following content:",
            "history": file_content,
            "language": report_language
        })
        report_content = result.get("answer", "Unable to generate report due to insufficient information.")
        pdf_path = create_pdf(report_content)
        return report_content, pdf_path
    except Exception as e:
        return f"An error occurred while processing the file: {str(e)}", None
def generate_interview_report(interview_history, language):
    """Produce a clinical report for a finished interview.

    ``interview_history`` is a sequence of transcript strings, joined with
    newlines before being sent to the report chain. ``language`` falls back to
    ``"english"`` when empty or ``None``. Returns ``(report_text, pdf_path)``.
    On any failure, returns ``(error_message, None)`` instead of raising.
    """
    try:
        if language:
            report_language = language.strip().lower()
        else:
            report_language = "english"
        print('preferred report_language language:', report_language)
        # Rebuild the chains so the report prompt uses the requested language.
        _interview_chain, report_chain, _retriever = setup_knowledge_retrieval(llm, report_language)
        transcript = "\n".join(interview_history)
        payload = {
            "input": "Please provide a clinical report based on the following interview:",
            "history": transcript,
            "language": report_language,
        }
        answer = report_chain.invoke(payload)
        report_text = answer.get("answer", "Unable to generate report due to insufficient information.")
        return report_text, create_pdf(report_text)
    except Exception as err:
        return f"An error occurred while generating the report: {str(err)}", None
def create_pdf(content):
    """Render *content* into a PDF file and return the file's path.

    Each line of *content* becomes one justified paragraph followed by a
    12-point spacer. Text wrapped in ``**double asterisks**`` is rendered in
    bold inline, inside the same paragraph. Previously each bold fragment
    became its own paragraph and so broke onto a separate line. ``&``, ``<``
    and ``>`` are escaped because ReportLab's ``Paragraph`` parses its text as
    XML-like markup.

    The file is created in the system temp directory with ``delete=False``;
    nothing here removes it, so cleanup is up to the caller.
    """
    from xml.sax.saxutils import escape

    # Create the temp file only to reserve a unique name, then close the
    # handle at once. This avoids leaking it, and on some platforms (e.g.
    # Windows) an open NamedTemporaryFile cannot be reopened by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix='_report.pdf') as tmp:
        pdf_path = tmp.name

    doc = SimpleDocTemplate(pdf_path, pagesize=letter)
    styles = getSampleStyleSheet()
    # Custom style for normal text with justification
    normal_style = ParagraphStyle('Normal', parent=styles['Normal'], alignment=TA_JUSTIFY)
    bold_pattern = re.compile(r'(\*\*.*?\*\*)')

    flowables = []
    for line in content.split('\n'):
        # With one capture group, re.split puts the **...** matches at odd indices.
        parts = bold_pattern.split(line)
        markup = "".join(
            f"<b>{escape(part[2:-2])}</b>" if index % 2 else escape(part)
            for index, part in enumerate(parts)
        )
        flowables.append(Paragraph(markup, normal_style))
        flowables.append(Spacer(1, 12))  # Add space between paragraphs
    doc.build(flowables)
    return pdf_path