Spaces:
Runtime error
Runtime error
# from typing import Any, Coroutine | |
import openai | |
import os | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.chat_models import AzureChatOpenAI | |
from langchain.document_loaders import DirectoryLoader | |
from langchain.chains import RetrievalQA | |
from langchain.vectorstores import Pinecone | |
from langchain.agents import initialize_agent | |
from langchain.agents import AgentType | |
from langchain.agents import Tool | |
from langchain.tools import BaseTool | |
from langchain.tools import DuckDuckGoSearchRun | |
from langchain.utilities import WikipediaAPIWrapper | |
from langchain.python import PythonREPL | |
import pinecone | |
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration | |
import gradio as gr | |
import time | |
class DB_Search(BaseTool):
    """Agent tool that answers a query through the ``QAQuery_p`` helper.

    ``QAQuery_p`` returns an ``(answer, source_documents)`` pair; only the
    answer text is handed back to the agent.
    """

    # Annotated so the overrides stay valid fields on pydantic-based BaseTool.
    name: str = "Vector Database Search"
    description: str = "This is the vector database to search information firstly"

    def _run(self, query: str) -> str:
        """Return the answer text produced by ``QAQuery_p`` for *query*."""
        response, _source = QAQuery_p(query)
        return response

    async def _arun(self, query: str) -> str:
        """Asynchronous execution is not supported by this tool.

        Declared ``async`` so that awaiting it behaves like any other tool.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("N/A")
# Backends for the extra tools offered to the agent.
Wikipedia = WikipediaAPIWrapper()
Netsearch = DuckDuckGoSearchRun()
Python_REPL = PythonREPL()

# NOTE: ``func`` must receive the bound ``run`` method itself. Calling it here
# (``run()``) would execute it at import time without a query and fail.
wikipedia_tool = Tool(
    name="Wikipedia Search",
    func=Wikipedia.run,
    description="Useful to search a topic, country or person when there is no availble information in vector database",
)
# The internet-search tool is backed by DuckDuckGo, not by the Python REPL.
duckduckgo_tool = Tool(
    name="Duckduckgo Internet Search",
    func=Netsearch.run,
    description="Useful to search information in internet when it is not available in other tools",
)
# The Python tool is backed by the REPL, not by the search engine.
python_tool = Tool(
    name="Python REPL",
    func=Python_REPL.run,
    description="Useful when you need python to answer questions. You should input python code.",
)

# Order matters only for how the tools are listed to the agent.
tools = [DB_Search(), wikipedia_tool, duckduckgo_tool, python_tool]
# --- Azure OpenAI configuration ---------------------------------------------
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_VERSION"] = "2023-05-15"
# The key and endpoint must already be present in the environment. Fail with a
# readable message instead of the TypeError that assigning None to os.environ
# would raise.
for _required_var in ("OPENAI_API_KEY", "OPENAI_API_BASE"):
    if os.getenv(_required_var) is None:
        raise RuntimeError(f"Environment variable {_required_var} is not set")

# --- UI access control --------------------------------------------------------
username = os.getenv("username")
password = os.getenv("password")
SysLock = os.getenv("SysLock")  # 0=unlock 1=lock

# --- Models -------------------------------------------------------------------
chat = AzureChatOpenAI(
    deployment_name="Chattester",
    temperature=0,
)
embeddings = OpenAIEmbeddings(deployment="model_embedding", chunk_size=15)

# --- Pinecone -----------------------------------------------------------------
pinecone.init(
    api_key=os.getenv("pinecone_api_key"),
    environment='asia-southeast1-gcp-free',
)
index_name = 'stla-baby'
index = pinecone.Index(index_name)

# --- Agent --------------------------------------------------------------------
llm = chat
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    handle_parsing_errors=True,
)

# --- Vector stores shared by the ingest/query helpers below --------------------
# (module-level names are already global; no ``global`` statement is needed)
vectordb = Chroma(persist_directory='db', embedding_function=embeddings)
vectordb_p = Pinecone.from_existing_index(index_name, embeddings)
def chathmi(message, history):
    """Chat callback that answers *message* directly through ``QAQuery_p``.

    This is a generator: it yields the answer text exactly once. The source
    documents returned by ``QAQuery_p`` are ignored and *history* is only
    printed.
    """
    answer, _sources = QAQuery_p(message)
    time.sleep(0.3)
    print(history)
    yield answer
def chathmi2(message, history):
    """Chat callback that routes *message* through the tool-using agent.

    This is a generator that always yields exactly one reply: the agent's
    output on success, or a short apology when the agent call fails (the error
    itself is printed, not shown to the user). *history* is only printed.
    """
    try:
        # Only the agent call is guarded; it is the step that can fail.
        output = agent.run(message)
    except Exception as e:  # UI boundary: never leave the user without a reply
        print("error:", e)
        yield "Sorry, something went wrong while processing your request. Please try again."
        return
    time.sleep(0.3)
    print("History: ", history)
    yield output
# Chat UI of the app: every user message is handled by the agent-backed
# ``chathmi2`` callback.
demo = gr.ChatInterface(
    fn=chathmi2,
    title="STLA BABY - YOUR FRIENDLY GUIDE ",
    description="v0.2: Powered by MECH Core Team",
)
def CreatDb_P():
    """Rebuild the Pinecone index from the ``.txt`` files under ``./documents``.

    The files are loaded and split into 500-character chunks with a
    200-character overlap. The existing vectors in the default namespace are
    then deleted and the chunks uploaded; the module-level ``vectordb_p`` is
    rebound to the refreshed store. When no chunks are produced the index is
    left untouched.
    """
    global vectordb_p
    index_name = 'stla-baby'
    documents = DirectoryLoader('./documents', glob='**/*.txt').load()
    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
    split_docs = text_splitter.split_documents(documents)
    print(split_docs)
    if not split_docs:
        # Nothing to upload: do not wipe the vectors that are already stored.
        print("No documents found, Pinecone index left unchanged")
        return
    # One handle for both the wipe and the stats, so they hit the same index.
    target_index = pinecone.Index(index_name)
    target_index.delete(delete_all=True, namespace='')
    vectordb_p = Pinecone.from_documents(split_docs, embeddings, index_name=index_name)
    print("Pinecone Updated Done")
    print(target_index.describe_index_stats())
def QAQuery_p(question: str):
    """Answer *question* with a RetrievalQA chain over the Pinecone store.

    The number of retrieved chunks is read from the ``search_kwargs_k``
    environment variable and falls back to 3 when it is unset or empty.

    Args:
        question: The user question passed to the chain as ``query``.

    Returns:
        A ``(response, source)`` tuple: the chain's ``result`` and its
        ``source_documents``.
    """
    retriever = vectordb_p.as_retriever()
    # ``or "3"`` covers both a missing and an empty variable; int(None) would raise.
    retriever.search_kwargs['k'] = int(os.getenv("search_kwargs_k") or "3")
    qa = RetrievalQA.from_chain_type(
        llm=chat,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
    res = qa({"query": question})
    response = res['result']
    source = res['source_documents']
    print("-" * 20)
    print("Question:", question)
    print("Answer:", response)
    print("-" * 20)
    print("Source:", source)
    return response, source
def CreatDb():
    """Build the local Chroma store from ``./documents`` and persist it.

    The ``.txt`` files are split into 500-character chunks with a
    200-character overlap, embedded into a Chroma store under ``db``, and the
    module-level ``vectordb`` is rebound to that store.
    """
    global vectordb
    raw_docs = DirectoryLoader('./documents', glob='**/*.txt').load()
    chunks = CharacterTextSplitter(chunk_size=500, chunk_overlap=200).split_documents(raw_docs)
    print(chunks)
    vectordb = Chroma.from_documents(chunks, embeddings, persist_directory='db')
    vectordb.persist()
def QAQuery(question: str):
    """Answer *question* with a RetrievalQA chain over the local Chroma store.

    Three chunks are retrieved per query. The question, the answer and the
    source documents are printed; only the answer text is returned.
    """
    store_retriever = vectordb.as_retriever()
    store_retriever.search_kwargs['k'] = 3
    chain = RetrievalQA.from_chain_type(
        llm=chat,
        chain_type="stuff",
        retriever=store_retriever,
        return_source_documents=True,
    )
    outcome = chain({"query": question})
    divider = "-" * 20
    print(divider)
    print("Question:", question)
    print("Answer:", outcome['result'])
    print(divider)
    print("Source:", outcome['source_documents'])
    return outcome['result']
# Used to complete content
def completeText(Text):
    """Print *Text* followed by its completion and a trailing period.

    The completion is requested from the ``Chattester`` deployment with
    temperature 0; nothing is returned.
    """
    completion = openai.Completion.create(
        deployment_id="Chattester",
        prompt=Text,
        temperature=0,
    )
    generated = completion['choices'][0]['text']
    print(f"{Text}{generated}.")
# Used to chat
def chatText(Text):
    """Send *Text* as a single user turn and print the assistant's reply.

    The conversation consists of a fixed system prompt plus *Text*; it is sent
    to the ``Chattester`` deployment. Nothing is returned.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": Text},
    ]
    reply = openai.ChatCompletion.create(messages=messages, deployment_id="Chattester")
    print("\n" + reply["choices"][0]["message"]["content"] + "\n")
if __name__ == '__main__':
    # Launch the chat UI; a login is required only when SysLock is "1".
    launch_options = {"auth": (username, password)} if SysLock == "1" else {}
    demo.queue().launch(**launch_options)