mratanusarkar commited on
Commit
fadf40f
·
1 Parent(s): a19b9c3

add: a basic Streamlit app implementation for the GUI

Browse files
Files changed (2) hide show
  1. app.py +87 -0
  2. pyproject.toml +11 -5
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Streamlit chat UI for the MedRAG multi-modal MedQA assistant.

Reads the W&B project/artifact and model settings from the sidebar, builds
the retrieval + LLM stack once per configuration (cached across Streamlit
reruns), and runs a simple chat loop backed by ``MedQAAssistant.predict``.
"""

import streamlit as st
import weave
from dotenv import load_dotenv

from medrag_multi_modal.assistant import (
    FigureAnnotatorFromPageImage,
    LLMClient,
    MedQAAssistant,
)
from medrag_multi_modal.retrieval import MedCPTRetriever

# Load environment variables (API keys for the LLM providers / W&B).
load_dotenv()

# Sidebar for configuration settings
st.sidebar.title("Configuration Settings")
project_name = st.sidebar.text_input(
    "Project Name",
    "ml-colabs/medrag-multi-modal",
)
chunk_dataset_name = st.sidebar.text_input(
    "Text Chunk WandB Dataset Name",
    "grays-anatomy-chunks:v0",
)
index_artifact_address = st.sidebar.text_input(
    "WandB Index Artifact Address",
    "ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
)
image_artifact_address = st.sidebar.text_input(
    "WandB Image Artifact Address",
    "ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)
llm_model_name = st.sidebar.text_input(
    "LLM Client Model Name",
    "gemini-1.5-flash",
)
figure_extraction_model_name = st.sidebar.text_input(
    "Figure Extraction Model Name",
    "pixtral-12b-2409",
)
structured_output_model_name = st.sidebar.text_input(
    "Structured Output Model Name",
    "gpt-4o",
)


@st.cache_resource
def _load_medqa_assistant(
    project_name: str,
    chunk_dataset_name: str,
    index_artifact_address: str,
    image_artifact_address: str,
    llm_model_name: str,
    figure_extraction_model_name: str,
    structured_output_model_name: str,
) -> MedQAAssistant:
    """Build the full assistant stack once and reuse it across reruns.

    Streamlit re-executes this script on every user interaction; without
    ``st.cache_resource`` the W&B index artifact would be re-fetched and
    every model client re-created on each message or widget change. The
    cache is keyed on the sidebar settings, so editing any of them
    rebuilds the stack with the new configuration.
    """
    # Initialize Weave tracing for the configured project.
    weave.init(project_name=project_name)

    llm_client = LLMClient(model_name=llm_model_name)
    retriever = MedCPTRetriever.from_wandb_artifact(
        chunk_dataset_name=chunk_dataset_name,
        index_artifact_address=index_artifact_address,
    )
    figure_annotator = FigureAnnotatorFromPageImage(
        figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name),
        structured_output_llm_client=LLMClient(model_name=structured_output_model_name),
        image_artifact_address=image_artifact_address,
    )
    return MedQAAssistant(
        llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
    )


medqa_assistant = _load_medqa_assistant(
    project_name,
    chunk_dataset_name,
    index_artifact_address,
    image_artifact_address,
    llm_model_name,
    figure_extraction_model_name,
    structured_output_model_name,
)

# Streamlit app layout
st.title("MedQA Assistant App")

# Initialize chat history (session_state persists it across reruns
# within one browser session).
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Replay prior messages so the conversation survives Streamlit reruns.
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Chat thread section with user input and response
if query := st.chat_input("What medical question can I assist you with today?"):
    # Record and echo the user's question.
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    # Process query and get response
    response = medqa_assistant.predict(query=query)
    st.session_state.chat_history.append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)
pyproject.toml CHANGED
@@ -44,9 +44,13 @@ dependencies = [
44
  "jsonlines>=4.0.0",
45
  "opencv-python>=4.10.0.84",
46
  "openai>=1.52.2",
 
47
  ]
48
 
49
  [project.optional-dependencies]
 
 
 
50
  core = [
51
  "adapters>=1.0.0",
52
  "bm25s[full]>=0.2.2",
@@ -74,10 +78,12 @@ core = [
74
  "opencv-python>=4.10.0.84",
75
  "openai>=1.52.2",
76
  ]
77
-
78
- dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]
79
-
80
-
 
 
81
  docs = [
82
  "mkdocs>=1.6.1",
83
  "mkdocstrings>=0.26.1",
@@ -91,4 +97,4 @@ docs = [
91
 
92
 
93
  [tool.pytest.ini_options]
94
- pythonpath = "."
 
44
  "jsonlines>=4.0.0",
45
  "opencv-python>=4.10.0.84",
46
  "openai>=1.52.2",
47
+ "streamlit>=1.39.0",
48
  ]
49
 
50
  [project.optional-dependencies]
51
+ app = [
52
+ "streamlit>=1.39.0",
53
+ ]
54
  core = [
55
  "adapters>=1.0.0",
56
  "bm25s[full]>=0.2.2",
 
78
  "opencv-python>=4.10.0.84",
79
  "openai>=1.52.2",
80
  ]
81
+ dev = [
82
+ "pytest>=8.3.3",
83
+ "isort>=5.13.2",
84
+ "black>=24.10.0",
85
+ "ruff>=0.6.9",
86
+ ]
87
  docs = [
88
  "mkdocs>=1.6.1",
89
  "mkdocstrings>=0.26.1",
 
97
 
98
 
99
  [tool.pytest.ini_options]
100
+ pythonpath = "."