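"""AI interview assistant.

Drives a question-by-question interview through an LLM retrieval chain,
converts each generated question to speech, and produces clinical reports
(as text and PDF) from either the interview transcript or an uploaded file.
"""
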
import traceback
from datetime import datetime
from pathlib import Path
import os
import random
import string
import tempfile
import re
import io
import PyPDF2
import docx
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.enums import TA_JUSTIFY
from ai_config import n_of_questions, load_model, openai_api_key, convert_text_to_speech
from knowledge_retrieval import setup_knowledge_retrieval, generate_report
# Initialize settings
n_of_questions = n_of_questions()
current_datetime = datetime.now()
human_readable_datetime = current_datetime.strftime("%B %d, %Y at %H:%M")
current_date = current_datetime.strftime("%Y-%m-%d")
# Initialize the model and retrieval chain
try:
    llm = load_model(openai_api_key)
    interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(llm)
    knowledge_base_connected = True
    print("Successfully connected to the knowledge base.")
except Exception as e:
    print(f"Error initializing the model or retrieval chain: {str(e)}")
    knowledge_base_connected = False
    print("Falling back to basic mode without knowledge base.")
question_count = 0
interview_history = []
last_audio_path = None  # Path of the most recently generated audio file
initial_audio_path = None  # Path of the initial question's audio file
language = None

def generate_random_string(length=5):
    return ''.join(random.choices(string.ascii_letters + string.digits, k=length))

def respond(message, history, voice, selected_interviewer):
    """Handle one chat turn: record the user's answer, produce the next question
    (or the final report), and return the updated history plus an audio file path."""
    global question_count, interview_history, combined_retriever, last_audio_path, initial_audio_path, language, interview_retrieval_chain, report_retrieval_chain

    if not isinstance(history, list):
        history = []
    if not history or not history[-1]:
        history.append(["", ""])

    # Extract the actual message text
    if isinstance(message, list):
        message = message[-1][0] if message and isinstance(message[-1], list) else message[-1]

    question_count += 1
    interview_history.append(f"Q{question_count}: {message}")
    history_str = "\n".join(interview_history)
    print("Starting interview question", question_count)
    try:
        if knowledge_base_connected:
            if question_count == 1:
                # Capture the language from the first response
                language = message.strip().lower()
                # Reinitialize the interview chain with the new language
                interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(
                    llm, language, selected_interviewer)

            if question_count < n_of_questions:
                result = interview_retrieval_chain.invoke({
                    "input": f"Based on the patient's statement: '{message}', what should be the next question?",
                    "history": history_str,
                    "question_number": question_count + 1,
                    "language": language
                })
                question = result.get("answer", f"Can you tell me more about that? (in {language})")
            else:
                result = generate_report(report_retrieval_chain, interview_history, language)
                question = result
                speech_file_path = None  # Skip audio generation for the report

            if question and question_count < n_of_questions:
                random_suffix = generate_random_string()
                speech_file_path = Path(__file__).parent / f"question_{question_count}_{random_suffix}.mp3"
                convert_text_to_speech(question, speech_file_path, voice)
                print(f"Question {question_count} saved as audio at {speech_file_path}")

                # Remove the last audio file if it exists
                if last_audio_path and os.path.exists(last_audio_path):
                    os.remove(last_audio_path)
                last_audio_path = speech_file_path
            else:
                speech_file_path = None  # No audio for the final report or an empty question
        else:
            # Fallback mode without knowledge base
            question = f"Can you elaborate on that? (in {language})" if language else "Can you elaborate on that?"
            if question_count < n_of_questions:
                speech_file_path = Path(__file__).parent / f"question_{question_count}.mp3"
                convert_text_to_speech(question, speech_file_path, voice)
                print(f"Question {question_count} saved as audio at {speech_file_path}")

                if last_audio_path and os.path.exists(last_audio_path):
                    os.remove(last_audio_path)
                last_audio_path = speech_file_path
            else:
                speech_file_path = None
        history[-1][1] = question

        # Remove the initial question audio file after the first user response
        if initial_audio_path and os.path.exists(initial_audio_path):
            os.remove(initial_audio_path)
            initial_audio_path = None

        # Clean up audio files left over from the previous question
        # (each file carries a random suffix, so match by pattern rather than exact name)
        if question_count > 1:
            for previous_audio_path in Path(__file__).parent.glob(f"question_{question_count - 1}_*.mp3"):
                os.remove(previous_audio_path)

        return history, str(speech_file_path) if speech_file_path else None

    except Exception as e:
        print(f"Error in retrieval chain: {str(e)}")
        print(traceback.format_exc())
        return history, None

def reset_interview():
    """Reset the interview state."""
    global question_count, interview_history, last_audio_path, initial_audio_path
    question_count = 0
    interview_history = []
    if last_audio_path and os.path.exists(last_audio_path):
        os.remove(last_audio_path)
    last_audio_path = None
    initial_audio_path = None

def read_file(file):
    """Return the text content of an uploaded .txt, .pdf, or .docx file."""
    if file is None:
        return "No file uploaded"

    # The upload may arrive either as a path string or as a file-like object
    # that exposes the underlying path through its .name attribute.
    path = file if isinstance(file, str) else getattr(file, 'name', None)
    if path is None:
        return "Unable to read file"

    if path.endswith('.txt'):
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    elif path.endswith('.pdf'):
        with open(path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    elif path.endswith('.docx'):
        doc = docx.Document(path)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs)
    else:
        return "Unsupported file format"

def generate_report_from_file(file, language):
    """Generate a clinical report (text plus PDF path) from an uploaded file."""
    try:
        file_content = read_file(file)
        if file_content in ("No file uploaded", "Unsupported file format", "Unable to read file"):
            return file_content, None

        file_content = file_content[:100000]  # Limit the amount of text sent to the model
        report_language = language.strip().lower() if language else "english"
        print(f"Generating report in language: {report_language}")

        # Reinitialize the report chain with the requested language
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        result = report_retrieval_chain.invoke({
            "input": "Please provide a clinical report based on the following content:",
            "history": file_content,
            "language": report_language
        })
        report_content = result.get("answer", "Unable to generate report due to insufficient information.")
        pdf_path = create_pdf(report_content)
        return report_content, pdf_path
    except Exception as e:
        return f"An error occurred while processing the file: {str(e)}", None

def generate_interview_report(interview_history, language):
    """Generate a clinical report (text plus PDF path) from the interview transcript."""
    try:
        report_language = language.strip().lower() if language else "english"
        print("Preferred report language:", report_language)
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        result = report_retrieval_chain.invoke({
            "input": "Please provide a clinical report based on the following interview:",
            "history": "\n".join(interview_history),
            "language": report_language
        })
        report_content = result.get("answer", "Unable to generate report due to insufficient information.")
        pdf_path = create_pdf(report_content)
        return report_content, pdf_path
    except Exception as e:
        return f"An error occurred while generating the report: {str(e)}", None

def create_pdf(content):
    """Render the report text to a PDF file and return its path.

    Segments wrapped in ** are rendered in bold; everything else is justified body text."""
    random_string = generate_random_string()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f'_{random_string}_report.pdf')
    doc = SimpleDocTemplate(temp_file.name, pagesize=letter)
    styles = getSampleStyleSheet()

    # Custom style for bold text
    bold_style = ParagraphStyle('Bold', parent=styles['Normal'], fontName='Helvetica-Bold', fontSize=10)
    # Custom style for normal text with justification
    normal_style = ParagraphStyle('Normal', parent=styles['Normal'], alignment=TA_JUSTIFY)

    flowables = []
    for line in content.split('\n'):
        # Split the line on **bold** markers, keeping the markers in the result
        parts = re.split(r'(\*\*.*?\*\*)', line)
        paragraph_parts = []
        for part in parts:
            if not part:
                continue
            if part.startswith('**') and part.endswith('**'):
                # Bold text
                paragraph_parts.append(Paragraph(part.strip('*'), bold_style))
            else:
                # Normal text
                paragraph_parts.append(Paragraph(part, normal_style))
        flowables.extend(paragraph_parts)
        flowables.append(Spacer(1, 12))  # Add space between paragraphs

    doc.build(flowables)
    return temp_file.name