Spaces:
Sleeping
Sleeping
File size: 3,127 Bytes
fadf40f 0dae114 fadf40f 0dae114 fadf40f 0dae114 fadf40f 0dae114 fadf40f 0dae114 fadf40f 0dae114 fadf40f 0dae114 fadf40f 0dae114 fadf40f 0dae114 fadf40f 076d575 fadf40f 0dae114 fadf40f 076d575 fadf40f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import streamlit as st
import weave
from medrag_multi_modal.assistant import (
FigureAnnotatorFromPageImage,
LLMClient,
MedQAAssistant,
)
from medrag_multi_modal.assistant.llm_client import (
GOOGLE_MODELS,
MISTRAL_MODELS,
OPENAI_MODELS,
)
from medrag_multi_modal.retrieval import MedCPTRetriever
# Pool of model identifiers offered by every model dropdown in the sidebar.
ALL_AVAILABLE_MODELS = GOOGLE_MODELS + MISTRAL_MODELS + OPENAI_MODELS


def _artifact_address_input(label: str, default: str) -> str:
    """Render a sidebar text box for a fully-qualified W&B artifact address.

    Args:
        label: Caption shown above the text box.
        default: Address pre-filled in the box.

    Returns:
        The address currently entered in the box.
    """
    return st.sidebar.text_input(
        label=label,
        value=default,
        placeholder="wandb artifact address",
        help="format: wandb_username/wandb_project_name/wandb_artifact_name:version",
    )


def _model_selector(label: str, default_model: str) -> str:
    """Render a sidebar dropdown listing every entry of ALL_AVAILABLE_MODELS.

    Args:
        label: Caption shown above the dropdown.
        default_model: Model preselected in the dropdown. It must be present in
            ALL_AVAILABLE_MODELS, otherwise the ``.index`` lookup raises
            ValueError.

    Returns:
        The model name currently selected in the dropdown.
    """
    return st.sidebar.selectbox(
        label=label,
        options=ALL_AVAILABLE_MODELS,
        index=ALL_AVAILABLE_MODELS.index(default_model),
        help="select a model from the list",
    )


# Sidebar: W&B locations followed by the model used for each role. Widgets are
# created in the order they should appear in the sidebar.
st.sidebar.title("Configuration Settings")
project_name = st.sidebar.text_input(
    label="Project Name",
    value="ml-colabs/medrag-multi-modal",
    placeholder="wandb project name",
    help="format: wandb_username/wandb_project_name",
)
chunk_dataset_name = st.sidebar.text_input(
    label="Text Chunk WandB Dataset Name",
    value="grays-anatomy-chunks:v0",
    placeholder="wandb dataset name",
    help="format: wandb_dataset_name:version",
)
index_artifact_address = _artifact_address_input(
    "WandB Index Artifact Address",
    "ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
)
image_artifact_address = _artifact_address_input(
    "WandB Image Artifact Address",
    "ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)
llm_client_model_name = _model_selector("LLM Client Model Name", "gemini-1.5-flash")
figure_extraction_model_name = _model_selector(
    "Figure Extraction Model Name", "pixtral-12b-2409"
)
structured_output_model_name = _model_selector(
    "Structured Output Model Name", "gpt-4o"
)
# Streamlit app layout
st.title("MedQA Assistant App")

# Initialize Weave. This call is deliberately left uncached so that every
# rerun uses the project currently entered in the sidebar.
weave.init(project_name=project_name)


# Streamlit re-executes this entire script on every widget interaction and on
# every submitted chat message. Without caching, each rerun would rebuild the
# retriever from its W&B artifact and rebuild the figure annotator from its
# image artifact. ``st.cache_resource`` keeps one instance per distinct set of
# arguments and rebuilds only when one of those sidebar settings changes.
# NOTE(review): cache_resource objects are shared by all sessions of this
# server process; this assumes the retriever and annotator tolerate
# concurrent use — confirm.
@st.cache_resource
def _load_retriever(chunk_dataset_name: str, index_artifact_address: str):
    """Build the MedCPT retriever from a W&B chunk dataset and index artifact.

    Args:
        chunk_dataset_name: Dataset name in ``name:version`` form.
        index_artifact_address: Artifact address in
            ``user/project/artifact:version`` form.

    Returns:
        The retriever returned by ``MedCPTRetriever.from_wandb_artifact``.
    """
    return MedCPTRetriever.from_wandb_artifact(
        chunk_dataset_name=chunk_dataset_name,
        index_artifact_address=index_artifact_address,
    )


@st.cache_resource
def _load_figure_annotator(
    figure_extraction_model_name: str,
    structured_output_model_name: str,
    image_artifact_address: str,
):
    """Build the page-image figure annotator with its own pair of LLM clients.

    Args:
        figure_extraction_model_name: Model used to extract figures.
        structured_output_model_name: Model used to produce structured output.
        image_artifact_address: W&B artifact address of the page images.

    Returns:
        A ``FigureAnnotatorFromPageImage`` instance.
    """
    return FigureAnnotatorFromPageImage(
        figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name),
        structured_output_llm_client=LLMClient(model_name=structured_output_model_name),
        image_artifact_address=image_artifact_address,
    )


# Initialize clients and assistants. The artifact-backed pieces come from the
# caches above. The main LLM client and the assistant that wires everything
# together are still constructed on every rerun.
llm_client = LLMClient(model_name=llm_client_model_name)
retriever = _load_retriever(chunk_dataset_name, index_artifact_address)
figure_annotator = _load_figure_annotator(
    figure_extraction_model_name, structured_output_model_name, image_artifact_address
)
medqa_assistant = MedQAAssistant(
    llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
)
# Streamlit re-executes this whole script on every interaction, so anything
# rendered on a previous run is gone unless it is drawn again. The transcript
# is therefore kept in per-session state and replayed before new input is
# handled. The assistant itself is still called with the latest query only;
# the stored history is used purely for display.
_CHAT_HISTORY_KEY = "medqa_chat_history"
if _CHAT_HISTORY_KEY not in st.session_state:
    st.session_state[_CHAT_HISTORY_KEY] = []

for past_message in st.session_state[_CHAT_HISTORY_KEY]:
    with st.chat_message(past_message["role"]):
        st.markdown(past_message["content"])

query = st.chat_input("Enter your question here")
if query:
    with st.chat_message("user"):
        st.markdown(query)
    response = medqa_assistant.predict(query=query)
    with st.chat_message("assistant"):
        st.markdown(response)
    # Record the exchange only after a successful prediction, so a failed call
    # does not leave an unanswered user turn in the transcript.
    st.session_state[_CHAT_HISTORY_KEY].extend(
        [
            {"role": "user", "content": query},
            {"role": "assistant", "content": response},
        ]
    )
|