ilj commited on
Commit
bbb957e
·
1 Parent(s): 8e14c81

add gemini

Browse files
Files changed (2) hide show
  1. app.py +7 -2
  2. langchain_pipeline.py +14 -3
app.py CHANGED
@@ -1,7 +1,12 @@
1
  import streamlit as st
2
- from langchain_pipeline import pipeline
3
 
4
- st.title("Composure AI")
 
 
 
 
 
5
 
6
  uploaded_file = st.file_uploader("Choose a file")
7
  if uploaded_file is not None:
 
1
  import streamlit as st
2
+ from langchain_pipeline import pipeline, model_names
3
 
4
+ st.title("Canarie AI Prototype")
5
+ st.subheader("Finding the canarie in the coal mine")
6
+
7
+ option = st.selectbox(
8
+ "Model",
9
+ model_names())
10
 
11
  uploaded_file = st.file_uploader("Choose a file")
12
  if uploaded_file is not None:
langchain_pipeline.py CHANGED
@@ -5,18 +5,29 @@ from langchain_astradb import AstraDBVectorStore
5
  from langchain_core.prompts import PromptTemplate
6
  from langchain_openai import OpenAIEmbeddings
7
  from langchain_anthropic import ChatAnthropic
 
8
 
9
  ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"]
10
  ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
11
  OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
12
  ANTHROPIC_API_KEY = os.environ["ANTHROPIC_API_KEY"]
 
13
 
14
  collection_name = "ilj_test"
15
 
16
  embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
17
- model = ChatAnthropic(model='claude-3-sonnet-20240229')
18
 
19
- def pipeline(bytes):
 
 
 
 
 
 
 
 
 
 
20
  disclosure_text = high_level.extract_text(bytes)
21
  # disclosure_text = doc[0].page_content
22
  #
@@ -52,6 +63,6 @@ def pipeline(bytes):
52
  )
53
  val = prompt.format(context=related_docs, disclosure={disclosure_text})
54
 
55
- chat_response = model.invoke(input=val)
56
 
57
  return chat_response.content
 
5
  from langchain_core.prompts import PromptTemplate
6
  from langchain_openai import OpenAIEmbeddings
7
  from langchain_anthropic import ChatAnthropic
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
 
10
  ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"]
11
  ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
12
  OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
13
  ANTHROPIC_API_KEY = os.environ["ANTHROPIC_API_KEY"]
14
+ GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
15
 
16
  collection_name = "ilj_test"
17
 
18
  embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
 
19
 
20
+ models = {
21
+ "claude-3": ChatAnthropic(model='claude-3-sonnet-20240229'),
22
+ "gemini-pro": ChatGoogleGenerativeAI(model="gemini-pro")
23
+ }
24
+
25
+
26
+ def model_names():
27
+ return models.keys()
28
+
29
+
30
+ def pipeline(bytes, model_name):
31
  disclosure_text = high_level.extract_text(bytes)
32
  # disclosure_text = doc[0].page_content
33
  #
 
63
  )
64
  val = prompt.format(context=related_docs, disclosure={disclosure_text})
65
 
66
+ chat_response = models[model_name].invoke(input=val)
67
 
68
  return chat_response.content