philmui commited on
Commit
f57b8d4
β€’
1 Parent(s): a7f3b3b

added semantic search of local books

Browse files
Files changed (6) hide show
  1. .gitignore +4 -0
  2. agents.py +63 -25
  3. app.py +5 -2
  4. data/machiavelli-the-prince.txt +0 -0
  5. data/sunzi-art-of-war.txt +0 -0
  6. models.py +35 -2
.gitignore CHANGED
@@ -158,3 +158,7 @@ cython_debug/
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
 
 
 
 
 
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
+
162
+ # ChromaDB
163
+ db/
164
+ chromadb/
agents.py CHANGED
@@ -11,7 +11,9 @@ from langchain.schema import HumanMessage
11
  from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
12
  HumanMessagePromptTemplate
13
  from models import load_chat_agent, load_chained_agent, load_sales_agent, \
14
- load_sqlite_agent
 
 
15
 
16
  import logging
17
 
@@ -68,6 +70,59 @@ def chatAgent(chat_message):
68
  output = "Please rephrase and try chat again."
69
  return output
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def agentController(question_text, model_name):
73
  output = ""
@@ -78,7 +133,13 @@ def agentController(question_text, model_name):
78
  elif is_magic(question_text, DIGITAL_MAGIC_TOKENS):
79
  output = chinookAgent(question_text, model_name)
80
  print(f"πŸ”Ή chinookAgent: {output}")
81
- else:
 
 
 
 
 
 
82
  try:
83
  instruction = instruct_prompt.format(query=question_text)
84
  logger.info(f"instruction: {instruction}")
@@ -94,26 +155,3 @@ def agentController(question_text, model_name):
94
  logger.error(e)
95
 
96
  return output
97
-
98
-
99
- def salesAgent(instruction):
100
- output = ""
101
- try:
102
- agent = load_sales_agent(verbose=True)
103
- output = agent.run(instruction)
104
- print("panda> " + output)
105
- except Exception as e:
106
- logger.error(e)
107
- output = f"Rephrasing your prompt could get better sales results {e}"
108
- return output
109
-
110
- def chinookAgent(instruction, model_name):
111
- output = ""
112
- try:
113
- agent = load_sqlite_agent(model_name)
114
- output = agent.run(instruction)
115
- print("chinook> " + output)
116
- except Exception as e:
117
- logger.error(e)
118
- output = "Rephrasing your prompt could get better db results {e}"
119
- return output
 
11
  from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
12
  HumanMessagePromptTemplate
13
  from models import load_chat_agent, load_chained_agent, load_sales_agent, \
14
+ load_sqlite_agent, load_book_agent
15
+
16
+ import openai, numpy as np
17
 
18
  import logging
19
 
 
70
  output = "Please rephrase and try chat again."
71
  return output
72
 
73
+ def salesAgent(instruction):
74
+ output = ""
75
+ try:
76
+ agent = load_sales_agent(verbose=True)
77
+ output = agent.run(instruction)
78
+ print("panda> " + output)
79
+ except Exception as e:
80
+ logger.error(e)
81
+ output = f"Rephrasing your prompt could get better sales results {e}"
82
+ return output
83
+
84
+ def chinookAgent(instruction, model_name):
85
+ output = ""
86
+ try:
87
+ agent = load_sqlite_agent(model_name)
88
+ output = agent.run(instruction)
89
+ print("chinook> " + output)
90
+ except Exception as e:
91
+ logger.error(e)
92
+ output = "Rephrasing your prompt could get better db results {e}"
93
+ return output
94
+
95
+ def semantically_similar(string1, string2):
96
+ #
97
+ # proper way to do this is to use a
98
+ # vector DB (chroma, pinecone, ...)
99
+ #
100
+ response = openai.Embedding.create(
101
+ input=[string1, string2],
102
+ engine="text-similarity-davinci-001"
103
+ )
104
+ embedding_a = response['data'][0]['embedding']
105
+ embedding_b = response['data'][1]['embedding']
106
+ similarity_score = np.dot(embedding_a, embedding_b)
107
+ logger.info(f"similarity: {similarity_score}")
108
+
109
+ return similarity_score > 0.8
110
+
111
+
112
+ def bookAgent(query):
113
+ output = ""
114
+ try:
115
+ agent = load_book_agent(True)
116
+ result = agent({
117
+ "query": query
118
+ })
119
+ logger.info(f"book response: {result['result']}")
120
+ output = result['result']
121
+ except Exception as e:
122
+ logger.error(e)
123
+ output = "Rephrasing your prompt for the book agent{e}"
124
+ return output
125
+
126
 
127
  def agentController(question_text, model_name):
128
  output = ""
 
133
  elif is_magic(question_text, DIGITAL_MAGIC_TOKENS):
134
  output = chinookAgent(question_text, model_name)
135
  print(f"πŸ”Ή chinookAgent: {output}")
136
+ elif semantically_similar(question_text, "fight a war"):
137
+ output = bookAgent(question_text)
138
+ print(f"πŸ”Ή bookAgent: {output}")
139
+ elif semantically_similar(question_text, "how to govern"):
140
+ output = bookAgent(question_text)
141
+ print(f"πŸ”Ή bookAgent: {output}")
142
+ else: # reasoning agents
143
  try:
144
  instruction = instruct_prompt.format(query=question_text)
145
  logger.info(f"instruction: {instruction}")
 
155
  logger.error(e)
156
 
157
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -9,7 +9,7 @@
9
 
10
  import streamlit as st
11
  from pprint import pprint
12
- from agents import agentController, salesAgent, chinookAgent, chatAgent
13
 
14
  ##############################################################################
15
 
@@ -104,7 +104,10 @@ with col2:
104
  value="πŸ”Ή For my company, what is the total sales " +
105
  "broken down by month?\n" +
106
  "πŸ”Ή How many total artists are there in each "+
107
- "genres in our digital media database?")
 
 
 
108
 
109
  with col3:
110
  st.markdown("__Enhanced reasoning__ [🎡](https://www.youtube.com/watch?v=hTTUaImgCyU&t=62s)")
 
9
 
10
  import streamlit as st
11
  from pprint import pprint
12
+ from agents import agentController , salesAgent, chinookAgent, chatAgent
13
 
14
  ##############################################################################
15
 
 
104
  value="πŸ”Ή For my company, what is the total sales " +
105
  "broken down by month?\n" +
106
  "πŸ”Ή How many total artists are there in each "+
107
+ "genres in our digital media database?\n" +
108
+ "πŸ”Ή How to best govern a city? (The Prince)\n" +
109
+ "πŸ”Ή How to win a war? (Art of War)",
110
+ )
111
 
112
  with col3:
113
  st.markdown("__Enhanced reasoning__ [🎡](https://www.youtube.com/watch?v=hTTUaImgCyU&t=62s)")
data/machiavelli-the-prince.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/sunzi-art-of-war.txt ADDED
The diff for this file is too large to render. See raw diff
 
models.py CHANGED
@@ -10,9 +10,14 @@ import pandas as pd
10
 
11
  from langchain.agents import AgentType, load_tools, initialize_agent,\
12
  create_pandas_dataframe_agent
 
13
  from langchain.chat_models import ChatOpenAI
14
  from langchain.llms import OpenAI
15
- from langchain import SQLDatabase, SQLDatabaseChain, HuggingFaceHub
 
 
 
 
16
 
17
  OPENAI_LLMS = [
18
  'text-davinci-003',
@@ -45,10 +50,38 @@ def createLLM(model_name="text-davinci-003", temperature=0):
45
  model_kwargs={"temperature":1e-10})
46
  return llm
47
 
48
-
49
  def load_chat_agent(verbose=True):
50
  return createLLM(OPENAI_CHAT_LLMS[0], temperature=0.5)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def load_sales_agent(verbose=True):
53
  '''
54
  Hard-coded agent that gates an internal sales CSV file for demo
 
10
 
11
  from langchain.agents import AgentType, load_tools, initialize_agent,\
12
  create_pandas_dataframe_agent
13
+ from langchain import SQLDatabase, SQLDatabaseChain, HuggingFaceHub
14
  from langchain.chat_models import ChatOpenAI
15
  from langchain.llms import OpenAI
16
+ from langchain.chains import RetrievalQA
17
+ from langchain.document_loaders import DirectoryLoader, TextLoader
18
+ from langchain.embeddings.openai import OpenAIEmbeddings
19
+ from langchain.vectorstores import Chroma
20
+ from langchain.text_splitter import CharacterTextSplitter
21
 
22
  OPENAI_LLMS = [
23
  'text-davinci-003',
 
50
  model_kwargs={"temperature":1e-10})
51
  return llm
52
 
 
53
  def load_chat_agent(verbose=True):
54
  return createLLM(OPENAI_CHAT_LLMS[0], temperature=0.5)
55
 
56
+ import os
57
+ import chromadb
58
+ from chromadb.config import Settings
59
+ DB_DIR = "./db"
60
+
61
+ def load_book_agent(verbose=True):
62
+ retriever = None
63
+ embeddings = OpenAIEmbeddings(openai_api_key = os.environ['OPENAI_API_KEY'])
64
+
65
+ if not os.path.exists(DB_DIR):
66
+ loader = DirectoryLoader(path="./data/", glob="**/*.txt")
67
+ docs = loader.load()
68
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
69
+ text_chunks = text_splitter.split_documents(documents=docs)
70
+ docsearch = Chroma.from_documents(text_chunks, embeddings,
71
+ persist_directory="./db")
72
+ retriever = docsearch.as_retriever()
73
+ else:
74
+ vectordb = Chroma(persist_directory=DB_DIR,
75
+ embedding_function=embeddings)
76
+ retriever = vectordb.as_retriever()
77
+
78
+ qa = RetrievalQA.from_chain_type(llm = OpenAI(temperature=0.9),
79
+ chain_type="stuff",
80
+ retriever=retriever,
81
+ return_source_documents=True
82
+ )
83
+ return qa
84
+
85
  def load_sales_agent(verbose=True):
86
  '''
87
  Hard-coded agent that gates an internal sales CSV file for demo