mratanusarkar commited on
Commit
0dae114
·
1 Parent(s): fadf40f

update: ui components for settings input

Browse files
Files changed (1) hide show
  1. app.py +40 -18
app.py CHANGED
@@ -7,47 +7,69 @@ from medrag_multi_modal.assistant import (
7
  LLMClient,
8
  MedQAAssistant,
9
  )
 
 
 
 
 
10
  from medrag_multi_modal.retrieval import MedCPTRetriever
11
 
12
  # Load environment variables
13
  load_dotenv()
14
 
 
 
 
15
  # Sidebar for configuration settings
16
  st.sidebar.title("Configuration Settings")
17
  project_name = st.sidebar.text_input(
18
- "Project Name",
19
- "ml-colabs/medrag-multi-modal"
 
 
20
  )
21
  chunk_dataset_name = st.sidebar.text_input(
22
- "Text Chunk WandB Dataset Name",
23
- "grays-anatomy-chunks:v0"
 
 
24
  )
25
  index_artifact_address = st.sidebar.text_input(
26
- "WandB Index Artifact Address",
27
- "ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
 
 
28
  )
29
  image_artifact_address = st.sidebar.text_input(
30
- "WandB Image Artifact Address",
31
- "ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
 
 
32
  )
33
- llm_model_name = st.sidebar.text_input(
34
- "LLM Client Model Name",
35
- "gemini-1.5-flash"
 
 
36
  )
37
- figure_extraction_model_name = st.sidebar.text_input(
38
- "Figure Extraction Model Name",
39
- "pixtral-12b-2409"
 
 
40
  )
41
- structured_output_model_name = st.sidebar.text_input(
42
- "Structured Output Model Name",
43
- "gpt-4o"
 
 
44
  )
45
 
46
  # Initialize Weave
47
  weave.init(project_name=project_name)
48
 
49
  # Initialize clients and assistants
50
- llm_client = LLMClient(model_name=llm_model_name)
51
  retriever = MedCPTRetriever.from_wandb_artifact(
52
  chunk_dataset_name=chunk_dataset_name,
53
  index_artifact_address=index_artifact_address,
 
7
  LLMClient,
8
  MedQAAssistant,
9
  )
10
+ from medrag_multi_modal.assistant.llm_client import (
11
+ GOOGLE_MODELS,
12
+ MISTRAL_MODELS,
13
+ OPENAI_MODELS,
14
+ )
15
  from medrag_multi_modal.retrieval import MedCPTRetriever
16
 
17
  # Load environment variables
18
  load_dotenv()
19
 
20
+ # Define constants
21
+ ALL_AVAILABLE_MODELS = GOOGLE_MODELS + MISTRAL_MODELS + OPENAI_MODELS
22
+
23
  # Sidebar for configuration settings
24
  st.sidebar.title("Configuration Settings")
25
  project_name = st.sidebar.text_input(
26
+ label="Project Name",
27
+ value="ml-colabs/medrag-multi-modal",
28
+ placeholder="wandb project name",
29
+ help="format: wandb_username/wandb_project_name",
30
  )
31
  chunk_dataset_name = st.sidebar.text_input(
32
+ label="Text Chunk WandB Dataset Name",
33
+ value="grays-anatomy-chunks:v0",
34
+ placeholder="wandb dataset name",
35
+ help="format: wandb_dataset_name:version",
36
  )
37
  index_artifact_address = st.sidebar.text_input(
38
+ label="WandB Index Artifact Address",
39
+ value="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
40
+ placeholder="wandb artifact address",
41
+ help="format: wandb_username/wandb_project_name/wandb_artifact_name:version",
42
  )
43
  image_artifact_address = st.sidebar.text_input(
44
+ label="WandB Image Artifact Address",
45
+ value="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
46
+ placeholder="wandb artifact address",
47
+ help="format: wandb_username/wandb_project_name/wandb_artifact_name:version",
48
  )
49
+ llm_client_model_name = st.sidebar.selectbox(
50
+ label="LLM Client Model Name",
51
+ options=ALL_AVAILABLE_MODELS,
52
+ index=ALL_AVAILABLE_MODELS.index("gemini-1.5-flash"),
53
+ help="select a model from the list",
54
  )
55
+ figure_extraction_model_name = st.sidebar.selectbox(
56
+ label="Figure Extraction Model Name",
57
+ options=ALL_AVAILABLE_MODELS,
58
+ index=ALL_AVAILABLE_MODELS.index("pixtral-12b-2409"),
59
+ help="select a model from the list",
60
  )
61
+ structured_output_model_name = st.sidebar.selectbox(
62
+ label="Structured Output Model Name",
63
+ options=ALL_AVAILABLE_MODELS,
64
+ index=ALL_AVAILABLE_MODELS.index("gpt-4o"),
65
+ help="select a model from the list",
66
  )
67
 
68
  # Initialize Weave
69
  weave.init(project_name=project_name)
70
 
71
  # Initialize clients and assistants
72
+ llm_client = LLMClient(model_name=llm_client_model_name)
73
  retriever = MedCPTRetriever.from_wandb_artifact(
74
  chunk_dataset_name=chunk_dataset_name,
75
  index_artifact_address=index_artifact_address,