medrag / app.py
mratanusarkar's picture
add: a basic streamlit app impl for gui
fadf40f
raw
history blame
2.76 kB
import streamlit as st
import weave
from dotenv import load_dotenv
from medrag_multi_modal.assistant import (
FigureAnnotatorFromPageImage,
LLMClient,
MedQAAssistant,
)
from medrag_multi_modal.retrieval import MedCPTRetriever
# Pull settings from a local .env file (if any) into the process environment
load_dotenv()

# Sidebar: every runtime setting is an editable text field with a default.
# Widgets are created in display order inside the sidebar container.
with st.sidebar:
    st.title("Configuration Settings")

    # Project passed to weave.init below
    project_name = st.text_input(
        label="Project Name",
        value="ml-colabs/medrag-multi-modal",
    )

    # Inputs handed to the retriever (chunk dataset + index artifact)
    chunk_dataset_name = st.text_input(
        label="Text Chunk WandB Dataset Name",
        value="grays-anatomy-chunks:v0",
    )
    index_artifact_address = st.text_input(
        label="WandB Index Artifact Address",
        value="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
    )

    # Image artifact handed to the figure annotator
    image_artifact_address = st.text_input(
        label="WandB Image Artifact Address",
        value="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
    )

    # Model names for the three LLM clients constructed below
    llm_model_name = st.text_input(
        label="LLM Client Model Name",
        value="gemini-1.5-flash",
    )
    figure_extraction_model_name = st.text_input(
        label="Figure Extraction Model Name",
        value="pixtral-12b-2409",
    )
    structured_output_model_name = st.text_input(
        label="Structured Output Model Name",
        value="gpt-4o",
    )
# Initialize Weave (kept outside the cache so it runs on every script run,
# exactly as before, and follows changes to the project name)
weave.init(project_name=project_name)


@st.cache_resource
def _load_assistant_components(
    llm_model_name: str,
    chunk_dataset_name: str,
    index_artifact_address: str,
    figure_extraction_model_name: str,
    structured_output_model_name: str,
    image_artifact_address: str,
):
    """Build the LLM client, retriever, figure annotator and QA assistant.

    Streamlit re-executes this script from the top on every user interaction,
    so constructing these objects at module level rebuilt all of them for
    each chat message. Wrapping construction in ``st.cache_resource`` means
    the objects are created once per distinct combination of arguments and
    reused on later reruns; editing any sidebar field produces a new cache
    key and therefore a fresh build.

    NOTE(review): cached resources are shared by all sessions of the app —
    confirm the assistant objects are safe to share across users.

    Args:
        llm_model_name: Model name for the main answering ``LLMClient``.
        chunk_dataset_name: Chunk dataset name passed to the retriever.
        index_artifact_address: Index artifact address passed to the retriever.
        figure_extraction_model_name: Model name for the figure-extraction
            ``LLMClient``.
        structured_output_model_name: Model name for the structured-output
            ``LLMClient``.
        image_artifact_address: Image artifact address passed to the figure
            annotator.

    Returns:
        A ``(llm_client, retriever, figure_annotator, medqa_assistant)`` tuple.
    """
    llm_client = LLMClient(model_name=llm_model_name)
    retriever = MedCPTRetriever.from_wandb_artifact(
        chunk_dataset_name=chunk_dataset_name,
        index_artifact_address=index_artifact_address,
    )
    figure_annotator = FigureAnnotatorFromPageImage(
        figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name),
        structured_output_llm_client=LLMClient(model_name=structured_output_model_name),
        image_artifact_address=image_artifact_address,
    )
    medqa_assistant = MedQAAssistant(
        llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
    )
    return llm_client, retriever, figure_annotator, medqa_assistant


# Initialize clients and assistants (served from the cache after the first run
# for a given set of sidebar values). All four module-level names are kept.
llm_client, retriever, figure_annotator, medqa_assistant = _load_assistant_components(
    llm_model_name=llm_model_name,
    chunk_dataset_name=chunk_dataset_name,
    index_artifact_address=index_artifact_address,
    figure_extraction_model_name=figure_extraction_model_name,
    structured_output_model_name=structured_output_model_name,
    image_artifact_address=image_artifact_address,
)
# Page heading
st.title("MedQA Assistant App")

# First run of a session: start with an empty conversation
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Replay every stored turn so the conversation stays visible after a rerun
for past_turn in st.session_state["chat_history"]:
    st.chat_message(past_turn["role"]).markdown(past_turn["content"])
# Chat input box; the body below runs only when the user submitted a
# non-empty question on this run
query = st.chat_input("What medical question can I assist you with today?")
if query:
    # Same list object that lives in session state, so appends persist
    history = st.session_state["chat_history"]

    # Record the user's turn, then show it
    history.append({"role": "user", "content": query})
    st.chat_message("user").markdown(query)

    # Ask the assistant, record its turn, then show it
    response = medqa_assistant.predict(query=query)
    history.append({"role": "assistant", "content": response})
    st.chat_message("assistant").markdown(response)