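"""Streamlit chat app for the MedQA Assistant.

Wires a MedCPT retriever, a figure annotator, and an LLM client from
medrag_multi_modal into a chat interface. Artifact addresses and model
names are configurable from the sidebar.
"""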

import streamlit as st
import weave
from dotenv import load_dotenv

from medrag_multi_modal.assistant import (
    FigureAnnotatorFromPageImage,
    LLMClient,
    MedQAAssistant,
)
from medrag_multi_modal.retrieval import MedCPTRetriever

# Load environment variables
load_dotenv()
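# NOTE: this assumes the API keys for the configured model providers
# (e.g. Gemini, Mistral, and OpenAI, given the default model names below)
# are available in a local .env file or the environment.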
# Sidebar for configuration settings | |
st.sidebar.title("Configuration Settings") | |
project_name = st.sidebar.text_input( | |
"Project Name", | |
"ml-colabs/medrag-multi-modal" | |
) | |
chunk_dataset_name = st.sidebar.text_input( | |
"Text Chunk WandB Dataset Name", | |
"grays-anatomy-chunks:v0" | |
) | |
index_artifact_address = st.sidebar.text_input( | |
"WandB Index Artifact Address", | |
"ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0", | |
) | |
image_artifact_address = st.sidebar.text_input( | |
"WandB Image Artifact Address", | |
"ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6", | |
) | |
llm_model_name = st.sidebar.text_input( | |
"LLM Client Model Name", | |
"gemini-1.5-flash" | |
) | |
figure_extraction_model_name = st.sidebar.text_input( | |
"Figure Extraction Model Name", | |
"pixtral-12b-2409" | |
) | |
structured_output_model_name = st.sidebar.text_input( | |
"Structured Output Model Name", | |
"gpt-4o" | |
) | |

# Initialize Weave for tracing calls in the configured W&B project
weave.init(project_name=project_name)
# Initialize clients and assistants | |
llm_client = LLMClient(model_name=llm_model_name) | |
retriever = MedCPTRetriever.from_wandb_artifact( | |
chunk_dataset_name=chunk_dataset_name, | |
index_artifact_address=index_artifact_address, | |
) | |
figure_annotator = FigureAnnotatorFromPageImage( | |
figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name), | |
structured_output_llm_client=LLMClient(model_name=structured_output_model_name), | |
image_artifact_address=image_artifact_address, | |
) | |
medqa_assistant = MedQAAssistant( | |
llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator | |
) | |
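
# The assistant is expected to answer a query by retrieving relevant text
# chunks, annotating figures on the matching page images, and prompting the
# LLM with the combined context (inferred from the component names above).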

# Streamlit app layout
st.title("MedQA Assistant App")

# Initialize chat history
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Display chat messages from history on app rerun
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Chat thread section with user input and response
if query := st.chat_input("What medical question can I assist you with today?"):
    # Add the user message to chat history and echo it in the thread
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    # Process the query and render the assistant's response
    response = medqa_assistant.predict(query=query)
    st.session_state.chat_history.append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)
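
# To run the app locally (assuming this file is saved as app.py and the
# dependencies above are installed):
#
#   streamlit run app.py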