File size: 3,127 Bytes
fadf40f
 
 
 
 
 
 
 
0dae114
 
 
 
 
fadf40f
 
0dae114
 
 
fadf40f
 
 
0dae114
 
 
 
fadf40f
 
0dae114
 
 
 
fadf40f
 
0dae114
 
 
 
fadf40f
 
0dae114
 
 
 
fadf40f
0dae114
 
 
 
 
fadf40f
0dae114
 
 
 
 
fadf40f
0dae114
 
 
 
 
fadf40f
 
076d575
 
 
fadf40f
 
 
 
0dae114
fadf40f
 
 
 
 
 
 
 
 
 
 
 
 
076d575
 
fadf40f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st
import weave

from medrag_multi_modal.assistant import (
    FigureAnnotatorFromPageImage,
    LLMClient,
    MedQAAssistant,
)
from medrag_multi_modal.assistant.llm_client import (
    GOOGLE_MODELS,
    MISTRAL_MODELS,
    OPENAI_MODELS,
)
from medrag_multi_modal.retrieval import MedCPTRetriever

# Combined model list offered by every model selectbox below
# (Google entries first, then Mistral, then OpenAI).
ALL_AVAILABLE_MODELS = GOOGLE_MODELS + MISTRAL_MODELS + OPENAI_MODELS

# Help strings shared by several sidebar widgets.
_ARTIFACT_ADDRESS_HELP = (
    "format: wandb_username/wandb_project_name/wandb_artifact_name:version"
)
_MODEL_SELECT_HELP = "select a model from the list"


def _sidebar_model_selectbox(label, default_model):
    """Render a sidebar selectbox over ALL_AVAILABLE_MODELS.

    The selectbox is pre-selected on ``default_model``; ``list.index`` raises
    ``ValueError`` if that name is not present in ALL_AVAILABLE_MODELS.
    Returns the model name currently selected by the user.
    """
    default_position = ALL_AVAILABLE_MODELS.index(default_model)
    return st.sidebar.selectbox(
        label=label,
        options=ALL_AVAILABLE_MODELS,
        index=default_position,
        help=_MODEL_SELECT_HELP,
    )


# Sidebar: free-text settings (project, dataset and artifact locations).
with st.sidebar:
    st.title("Configuration Settings")
    project_name = st.text_input(
        label="Project Name",
        value="ml-colabs/medrag-multi-modal",
        placeholder="wandb project name",
        help="format: wandb_username/wandb_project_name",
    )
    chunk_dataset_name = st.text_input(
        label="Text Chunk WandB Dataset Name",
        value="grays-anatomy-chunks:v0",
        placeholder="wandb dataset name",
        help="format: wandb_dataset_name:version",
    )
    index_artifact_address = st.text_input(
        label="WandB Index Artifact Address",
        value="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
        placeholder="wandb artifact address",
        help=_ARTIFACT_ADDRESS_HELP,
    )
    image_artifact_address = st.text_input(
        label="WandB Image Artifact Address",
        value="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
        placeholder="wandb artifact address",
        help=_ARTIFACT_ADDRESS_HELP,
    )

# Sidebar: one model choice per role, rendered in the same order as before.
llm_client_model_name = _sidebar_model_selectbox(
    "LLM Client Model Name", "gemini-1.5-flash"
)
figure_extraction_model_name = _sidebar_model_selectbox(
    "Figure Extraction Model Name", "pixtral-12b-2409"
)
structured_output_model_name = _sidebar_model_selectbox(
    "Structured Output Model Name", "gpt-4o"
)

# Streamlit app layout
st.title("MedQA Assistant App")

# Initialize Weave
weave.init(project_name=project_name)


@st.cache_resource
def _load_retriever(
    project_name: str, chunk_dataset_name: str, index_artifact_address: str
) -> MedCPTRetriever:
    """Build the MedCPT retriever once per distinct configuration.

    Streamlit re-executes this whole script on every widget change and every
    chat message, so an uncached ``from_wandb_artifact`` call would be repeated
    on each interaction. ``st.cache_resource`` keeps a single retriever
    instance per unique argument combination and returns it on later reruns.

    Args:
        project_name: Weave project currently initialised. It is not used in
            the body; it is only part of the cache key so that switching
            projects forces a rebuild.
            NOTE(review): presumably the chunk dataset name is resolved
            relative to the active project - confirm.
        chunk_dataset_name: Dataset name forwarded to ``from_wandb_artifact``.
        index_artifact_address: Index artifact address forwarded to
            ``from_wandb_artifact``.

    Returns:
        The retriever produced by ``MedCPTRetriever.from_wandb_artifact``.
    """
    return MedCPTRetriever.from_wandb_artifact(
        chunk_dataset_name=chunk_dataset_name,
        index_artifact_address=index_artifact_address,
    )


# Initialize clients and assistants
llm_client = LLMClient(model_name=llm_client_model_name)
# The retriever is the only object served from the cache; the remaining
# objects are rebuilt on every rerun exactly as before.
retriever = _load_retriever(
    project_name=project_name,
    chunk_dataset_name=chunk_dataset_name,
    index_artifact_address=index_artifact_address,
)
figure_annotator = FigureAnnotatorFromPageImage(
    figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name),
    structured_output_llm_client=LLMClient(model_name=structured_output_model_name),
    image_artifact_address=image_artifact_address,
)
medqa_assistant = MedQAAssistant(
    llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
)

# Handle a single chat turn: when the user submits a question, echo it in a
# "user" bubble, ask the assistant, and render the answer in an "assistant"
# bubble. Nothing is rendered while the input box has no submission.
if query := st.chat_input("Enter your question here"):
    st.chat_message("user").markdown(query)
    response = medqa_assistant.predict(query=query)
    st.chat_message("assistant").markdown(response)