# Captured from a Hugging Face Space page (status banner at capture time: "Runtime error").
# from typing import Any, Coroutine | |
import openai | |
import os | |
# from langchain.vectorstores import Chroma | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.chat_models import AzureChatOpenAI | |
from langchain.document_loaders import DirectoryLoader | |
from langchain.chains import RetrievalQA | |
from langchain.vectorstores import Pinecone | |
from langchain.agents import initialize_agent | |
from langchain.agents import AgentType | |
from langchain.agents import Tool | |
# from langchain.agents import load_tools | |
from langchain.tools import BaseTool | |
from langchain.tools import DuckDuckGoSearchRun | |
from langchain.utilities import WikipediaAPIWrapper | |
from langchain.python import PythonREPL | |
from langchain.chains import LLMMathChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain.memory import ConversationBufferWindowMemory | |
from langchain.agents import ZeroShotAgent, AgentExecutor | |
from langchain.agents import OpenAIMultiFunctionsAgent | |
from langchain.prompts import MessagesPlaceholder | |
from langchain.schema.messages import ( | |
AIMessage, | |
BaseMessage, | |
FunctionMessage, | |
SystemMessage, | |
) | |
from langchain import LLMChain | |
import azure.cognitiveservices.speech as speechsdk | |
import requests | |
import sys | |
import pinecone | |
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration | |
import gradio as gr | |
import time | |
import glob | |
from typing import List | |
from multiprocessing import Pool | |
from tqdm import tqdm | |
from langchain.document_loaders import ( | |
CSVLoader, | |
EverNoteLoader, | |
PyMuPDFLoader, | |
TextLoader, | |
UnstructuredEmailLoader, | |
UnstructuredEPubLoader, | |
UnstructuredHTMLLoader, | |
UnstructuredMarkdownLoader, | |
UnstructuredODTLoader, | |
UnstructuredPowerPointLoader, | |
UnstructuredWordDocumentLoader, | |
) | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.docstore.document import Document | |
import langchain | |
# Turn on LangChain's global debug flag.
langchain.debug = True

# Conversation memories. The `memory=` arguments of the executors built later
# in this file are commented out, so neither is attached to an agent here.
# (The original `global` statements were dropped: at module level they do nothing.)
memory2 = ConversationBufferWindowMemory(memory_key="chat_history")
memory_openai = ConversationBufferWindowMemory(memory_key="memory", return_messages=True)

# Most recent user message: written by chathmi3/display_input, read by
# retry/Inference_Agent.
last_request = ""
# Custom document loaders | |
class MyElmLoader(UnstructuredEmailLoader):
    """Email loader that retries as plain text when no HTML part is found."""

    def load(self) -> List[Document]:
        """Load the email, falling back to ``text/plain`` content if needed.

        Any failure is re-raised as the same exception type with the file
        path prefixed to the message.
        """
        try:
            try:
                documents = super().load()
            except ValueError as err:
                if 'text/html content not found in email' not in str(err):
                    raise
                # No HTML body: switch the content source and try once more.
                self.unstructured_kwargs["content_source"] = "text/plain"
                documents = super().load()
        except Exception as err:
            # Prefix the offending path so failures in a batch are traceable.
            raise type(err)(f"{self.file_path}: {err}") from err
        return documents
# Maps a lower-case file extension (with leading dot) to a pair of
# (loader class, keyword arguments passed to the loader constructor).
# load_single_document() looks extensions up here and load_documents()
# globs once per key.
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    # ".docx": (Docx2txtLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".eml": (MyElmLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PyMuPDFLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    # Add more mappings for other file extensions and loaders as needed
}
# Directory scanned by load_documents() / process_documents().
source_directory = 'Upload Files'

# Paths of the files most recently uploaded through the UI. func_upload_file
# replaces this with a list of paths; process_documents_2 reads it.
file_list_loaded = ''

# Settings handed to RecursiveCharacterTextSplitter.
chunk_size = 500
chunk_overlap = 300

# Names of the audio files written by text_to_speech_2.
Audio_output = []

# Audio file name parsed from the latest agent reply (set by chathmi3,
# read by playsound).
Filename_Chatbot = ""
def load_single_document(file_path: str) -> List[Document]:
    """Load one file with the loader registered for its extension.

    The extension is taken from the final path component and compared
    case-insensitively, so ``report.PDF`` resolves to the ``.pdf`` loader and
    a dot inside a directory name is not mistaken for an extension.

    Args:
        file_path: Path of the file to load.

    Returns:
        The documents produced by the matching loader.

    Raises:
        ValueError: If no loader is registered for the file's extension.
    """
    ext = os.path.splitext(file_path)[1].lower()
    loader_entry = LOADER_MAPPING.get(ext)
    if loader_entry is None:
        raise ValueError(f"Unsupported file extension '{ext}'")
    loader_class, loader_args = loader_entry
    return loader_class(file_path, **loader_args).load()
def load_documents(source_dir: str, ignored_files: List[str] = None) -> List[Document]:
    """Load every supported file found under ``source_dir``.

    Files are discovered recursively, once per extension registered in
    ``LOADER_MAPPING``, and loaded in parallel with a process pool.

    Args:
        source_dir: Root directory to scan.
        ignored_files: Paths to skip. ``None`` (the default) skips nothing.

    Returns:
        All loaded documents, in completion order (not discovery order).
    """
    # A set makes the membership test O(1) per file.
    ignored = set(ignored_files or ())

    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )
    filtered_files = [file_path for file_path in all_files if file_path not in ignored]

    # Nothing to do: avoid spawning worker processes for an empty job list.
    if not filtered_files:
        return []

    results = []
    with Pool(processes=os.cpu_count()) as pool:
        with tqdm(total=len(filtered_files), desc='Loading new documents', ncols=80) as pbar:
            for docs in pool.imap_unordered(load_single_document, filtered_files):
                results.extend(docs)
                pbar.update()
    return results
def load_documents_2(all_files: List[str] = None, ignored_files: List[str] = None) -> List[Document]:
    """Load an explicit list of files sequentially.

    Unlike ``load_documents`` this does not scan a directory; the caller
    supplies the paths.

    Args:
        all_files: Paths to load. ``None`` or an empty value loads nothing.
        ignored_files: Paths to skip. ``None`` (the default) skips nothing.

    Returns:
        The documents of every loaded file, in input order.

    Raises:
        ValueError: Propagated from ``load_single_document`` for a file with
            an unsupported extension.
    """
    ignored = set(ignored_files or ())
    filtered_files = [file_path for file_path in (all_files or []) if file_path not in ignored]

    results = []
    with tqdm(total=len(filtered_files), desc='Loading new documents', ncols=80) as pbar:
        for file_path in filtered_files:
            results.extend(load_single_document(file_path))
            pbar.update()
    return results
def process_documents(ignored_files: List[str] = None) -> List[Document]:
    """Load everything under ``source_directory`` and split it into chunks.

    Args:
        ignored_files: Paths to skip while loading. ``None`` skips nothing.

    Returns:
        The chunks produced by ``RecursiveCharacterTextSplitter``.

    Raises:
        SystemExit: With status 0 when there is nothing to load (same
            behaviour as before, but without relying on the ``exit`` helper
            that only exists when the ``site`` module is loaded).
    """
    print(f"Loading documents from {source_directory}")
    documents = load_documents(source_directory, ignored_files or [])
    if not documents:
        print("No new documents to load")
        raise SystemExit(0)
    print(f"Loaded {len(documents)} new documents from {source_directory}")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(documents)
    # No length function is passed to the splitter, so the limit is reported
    # in characters rather than tokens.
    print(f"Split into {len(texts)} chunks of text (max. {chunk_size} characters each)")
    return texts
def process_documents_2(ignored_files: List[str] = None) -> List[Document]:
    """Load the files listed in ``file_list_loaded`` and split them into chunks.

    Args:
        ignored_files: Paths to skip while loading. ``None`` skips nothing.

    Returns:
        The chunks produced by ``RecursiveCharacterTextSplitter``, or an empty
        list when nothing could be loaded. This function is reached from a UI
        upload handler, so it no longer calls ``exit(0)``, which would raise
        SystemExit inside the running web application.
    """
    global file_list_loaded
    print(f"Loading documents from {source_directory}")
    print("File Path to start processing:", file_list_loaded)
    documents = load_documents_2(file_list_loaded, ignored_files or [])
    if not documents:
        print("No new documents to load")
        return []
    print(f"Loaded {len(documents)} new documents from {source_directory}")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(documents)
    # No length function is passed to the splitter, so the limit is reported
    # in characters rather than tokens.
    print(f"Split into {len(texts)} chunks of text (max. {chunk_size} characters each)")
    return texts
def UpdateDb():
    """Embed the most recently uploaded files and add them to the Pinecone index.

    Rebinds the module-level ``vectordb_p`` to the store returned by
    ``Pinecone.from_documents``. Does nothing when there are no chunks to add.
    """
    global vectordb_p
    split_docs = process_documents_2()
    if not split_docs:
        # Previously an empty result crashed on the indexing below.
        print("No documents to add; Pinecone index left unchanged")
        return
    # Show the last chunk as a quick sanity check of the split.
    print(split_docs[-1])
    print("Creating embeddings. May take some minutes...")
    # Use the shared module-level index name instead of repeating the literal.
    vectordb_p = Pinecone.from_documents(split_docs, embeddings, index_name=index_name)
    print("Pinecone Updated Done")
    print(index.describe_index_stats())
class DB_Search(BaseTool):
    """Agent tool that answers a query from the internal vector database."""

    name = "Vector_Database_Search"
    description = "This is the internal database to search information firstly. If information is found, it is trustful."

    def _run(self, query: str) -> str:
        """Return the answer text from ``QAQuery_p``; the sources are dropped."""
        answer, _sources = QAQuery_p(query)
        return answer

    def _arun(self, query: str):
        """Async execution is not supported by this tool."""
        raise NotImplementedError("N/A")
class DB_Search2(BaseTool):
    """Same tool as ``DB_Search`` but registered under a name with spaces."""

    name = "Vector Database Search"
    description = "This is the internal database to search information firstly. If information is found, it is trustful."

    def _run(self, query: str) -> str:
        """Return the answer text from ``QAQuery_p``; the sources are dropped."""
        answer, _sources = QAQuery_p(query)
        return answer

    def _arun(self, query: str):
        """Async execution is not supported by this tool."""
        raise NotImplementedError("N/A")
def Text2Sound(text):
    """Speak ``text`` through the Azure Speech SDK on the default speaker.

    Credentials come from the ``SPEECH_KEY`` and ``SPEECH_REGION`` environment
    variables. Returns the synthesis result object from the SDK.
    """
    config = speechsdk.SpeechConfig(
        subscription=os.getenv('SPEECH_KEY'),
        region=os.getenv('SPEECH_REGION'),
    )
    speaker_output = speechsdk.audio.AudioOutputConfig(use_default_speaker=True)
    config.speech_synthesis_voice_name = 'en-US-JennyNeural'
    synthesizer = speechsdk.SpeechSynthesizer(speech_config=config, audio_config=speaker_output)
    # speak_text_async returns a future; .get() blocks until it completes.
    result = synthesizer.speak_text_async(text).get()
    print("test")
    return result
def get_azure_access_token():
    """Exchange the Azure Speech key for a short-lived bearer token.

    The key is read from the ``SPEECH_KEY`` environment variable and posted
    to the ``eastus`` token endpoint.

    Returns:
        The token text, or ``None`` when the key is missing or the request
        fails (the error is printed).
    """
    azure_key = os.environ.get("SPEECH_KEY")
    if not azure_key:
        print("Error: SPEECH_KEY is not set")
        return None
    try:
        response = requests.post(
            "https://eastus.api.cognitive.microsoft.com/sts/v1.0/issuetoken",
            headers={
                "Ocp-Apim-Subscription-Key": azure_key
            },
            # Without a timeout a stalled connection blocks the caller forever.
            timeout=10,
        )
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return None
    return response.text
def text_to_speech_2(text):
    """Synthesize ``text`` with the Azure TTS REST API and save it as a WAV file.

    The file is written to the working directory as
    ``sample-YYYYMMDD-HHMM.wav`` and its name is appended to the module-level
    ``Audio_output`` list.

    Args:
        text: Plain text to speak. It is XML-escaped before being embedded in
            the SSML request, so characters such as ``&`` or ``<`` no longer
            produce an invalid document.

    Returns:
        The (already closed) file object of the written WAV file, as before,
        or ``None`` when no token could be obtained or the request failed.
    """
    # Local import: the module's import block does not provide this helper.
    from xml.sax.saxutils import escape

    global Audio_output
    access_token = get_azure_access_token()
    voice_name = 'en-US-AriaNeural'
    if not access_token:
        return None

    ssml = (
        "\n<speak version='1.0' xml:lang='en-US'>\n"
        f"<voice name='{voice_name}'>\n"
        f"{escape(str(text))}\n"
        "</voice>\n"
        "</speak>\n"
    )
    try:
        response = requests.post(
            "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1",
            headers={
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/ssml+xml",
                "X-MICROSOFT-OutputFormat": "riff-24khz-16bit-mono-pcm",
                "User-Agent": "TextToSpeechApp",
            },
            # Send bytes so the encoding does not depend on the HTTP library.
            # NOTE(review): a str body may be sent as Latin-1 and fail on
            # characters outside it - confirm for the installed version.
            data=ssml.encode("utf-8"),
            # Avoid hanging forever on a stalled connection.
            timeout=30,
        )
        response.raise_for_status()
        timestr = time.strftime("%Y%m%d-%H%M")
        with open('sample-' + timestr + '.wav', 'wb') as audio:
            audio.write(response.content)
            print("File Name ", audio.name)
            Audio_output.append(audio.name)
        return audio
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return None
# Shared back-ends wrapped by the tools below.
Wikipedia = WikipediaAPIWrapper()
Netsearch = DuckDuckGoSearchRun()
Python_REPL = PythonREPL()

# Every tool is defined twice: once with an underscore name and once (suffix
# "2") with a name containing spaces. Both variants wrap the same callable.

# --- text to speech -------------------------------------------------------
Text2Sound_tool = Tool(
    name="Text_To_Sound_REST_API",
    func=text_to_speech_2,
    description="Useful when you need to convert text into sound file.",
)
Text2Sound_tool2 = Tool(
    name="Text To Sound REST API",
    func=text_to_speech_2,
    description="Useful when you need to convert text into sound file.",
)

# --- Wikipedia --------------------------------------------------------------
wikipedia_tool = Tool(
    name="Wikipedia_Search",
    func=Wikipedia.run,
    description="Useful to search a topic, country or person when there is no availble information in vector database",
)
wikipedia_tool2 = Tool(
    name="Wikipedia Search",
    func=Wikipedia.run,
    description="Useful to search a topic, country or person when there is no availble information in vector database",
)

# --- web search -------------------------------------------------------------
duckduckgo_tool = Tool(
    name="Duckduckgo_Internet_Search",
    func=Netsearch.run,
    description="Useful to search information in internet when it is not available in other tools",
)
duckduckgo_tool2 = Tool(
    name="Duckduckgo Internet Search",
    func=Netsearch.run,
    description="Useful to search information in internet when it is not available in other tools",
)

# --- Python REPL ------------------------------------------------------------
python_tool = Tool(
    name="Python_REPL",
    func=Python_REPL.run,
    description="Useful when you need python to answer questions. You should input python code.",
)
python_tool2 = Tool(
    name="Python REPL",
    func=Python_REPL.run,
    description="Useful when you need python to answer questions. You should input python code.",
)
# tools = [DB_Search(), wikipedia_tool, duckduckgo_tool, python_tool] | |
# --- Azure OpenAI configuration ---------------------------------------------
os.environ["OPENAI_API_TYPE"] = "azure"
# These settings must already be present in the environment. The former
# `os.environ[X] = os.getenv(X)` self-assignments changed nothing when X was
# set and raised an opaque TypeError when it was not; fail with a clear
# message instead.
for _required in ("OPENAI_API_KEY", "OPENAI_API_BASE", "OPENAI_API_VERSION"):
    if os.getenv(_required) is None:
        raise RuntimeError(f"Environment variable {_required} must be set")
# os.environ["OPENAI_API_VERSION"] = "2023-05-15"

# Optional UI login, used by the launch code at the bottom of the file.
username = os.getenv("username")
password = os.getenv("password")
SysLock = os.getenv("SysLock")  # 0=unlock 1=lock

# deployment_name="Chattester"
chat = AzureChatOpenAI(
    deployment_name=os.getenv("deployment_name"),
    temperature=0,
)
llm = chat

llm_math = LLMMathChain.from_llm(llm)
math_tool = Tool(
    name='Calculator',
    func=llm_math.run,
    description='Useful for when you need to answer questions about math.',
)

# `tools` uses the underscore tool names, `tools2` the names with spaces
# (and additionally includes the Wikipedia tool).
tools = [DB_Search(), duckduckgo_tool, python_tool, math_tool, Text2Sound_tool]
tools2 = [DB_Search2(), duckduckgo_tool2, wikipedia_tool2, python_tool2, math_tool, Text2Sound_tool2]
# tools = load_tools(["Vector Database Search","Wikipedia Search","Python REPL","llm-math"], llm=llm)

embeddings = OpenAIEmbeddings(deployment="model_embedding", chunk_size=15)

# --- Pinecone ----------------------------------------------------------------
pinecone.init(
    api_key=os.getenv("pinecone_api_key"),
    environment='asia-southeast1-gcp-free',
    # openapi_config=openapi_config
)
index_name = 'stla-baby'
index = pinecone.Index(index_name)
# index.delete(delete_all=True, namespace='')
# print(pinecone.whoami())
# print(index.describe_index_stats())

# Earlier draft of the full prompt. It used to sit here as a bare string
# expression, which Python evaluates and discards; kept as a comment:
#   Answer the following questions as best you can with details.
#   You can always use tools to convert text to sound.
#   You must always check internal vector database first and try to answer the question based on the information in internal vector database only.
#   Only when there is no information available from vector database, you can search information by using other tools.
#   You have access to the following tools:
#   Vector Database Search: This is the internal database to search information firstly. If information is found, it is trustful.
#   Duckduckgo Internet Search: Useful to search information in internet when it is not available in other tools.
#   Wikipedia Search: Useful to search a topic, country or person when there is no availble information in vector database
#   Python REPL: Useful when you need python to answer questions. You should input python code.
#   Calculator: Useful for when you need to answer questions about math.
#   Text To Sound: Useful when you need to convert text into sound file.

# Text placed before the generated tool list.
PREFIX = """Answer the following questions as best you can with detail information and explanation.
You can always use tools to convert text to sound.
You must always check vector database first and try to answer the question based on the information in vector database only.
Only when there is no information available from vector database, you can search information by using other tools.
You have access to the following tools:"""

# `{tool_names}` is filled in by ZeroShotAgent.create_prompt with the names of
# the tools the agent is actually given. The previous hard-coded list left out
# the Wikipedia and text-to-sound tools and could not match both tool sets.
FORMAT_INSTRUCTIONS = """Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
(this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question"""

# Suffix without conversation history (the one currently in use).
SUFFIX = """Begin!
Request: {input}
Thought:{agent_scratchpad}"""

# Suffix variant that also injects `{chat_history}`; not used in this section.
SUFFIX2 = """Begin!
{chat_history}
Question: {input}
Thought:{agent_scratchpad}"""
# ReAct prompt for the hand-built ZeroShotAgent. It is built from `tools2`
# because that is the tool list the agent and its executor below are given;
# building it from `tools` advertised tool names the executor does not have.
prompt = ZeroShotAgent.create_prompt(
    tools2,
    prefix=PREFIX,
    suffix=SUFFIX,
    # suffix=SUFFIX2,
    format_instructions=FORMAT_INSTRUCTIONS,
    input_variables=["input", "agent_scratchpad"],
    # input_variables=["input", "chat_history", "agent_scratchpad"]
)
prompt_openai = OpenAIMultiFunctionsAgent.create_prompt(
    system_message=SystemMessage(content="You are a helpful AI assistant."),
    # extra_prompt_messages = [MessagesPlaceholder(variable_name="memory")],
)
input_variables = ["input", "chat_history", "agent_scratchpad"]

# Read the iteration cap once. Previously `int(os.getenv(...))` was repeated
# three times and raised TypeError when the variable was unset.
# NOTE(review): 15 is believed to match AgentExecutor's own default - confirm.
_max_iterations = int(os.getenv("max_iterations", "15"))

# Variant 1: ReAct agent assembled by `initialize_agent`.
agent_ZEROSHOT_REACT = initialize_agent(
    tools2,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    handle_parsing_errors=True,
    max_iterations=_max_iterations,
    early_stopping_method="generate",
    agent_kwargs={
        'prefix': PREFIX,
        'format_instructions': FORMAT_INSTRUCTIONS,
        'suffix': SUFFIX,
        # 'input_variables': input_variables,
    },
)

llm_chain = LLMChain(llm=llm, prompt=prompt)
# llm_chain_openai = LLMChain(llm=llm, prompt=prompt_openai, verbose=True)
agent_core = ZeroShotAgent(llm_chain=llm_chain, tools=tools2, verbose=True)
agent_core_openai = OpenAIMultiFunctionsAgent(llm=llm, tools=tools, prompt=prompt_openai, verbose=True)

# Variant 2: hand-built ZeroShotAgent wrapped in an executor.
agent_ZEROSHOT_AGENT = AgentExecutor.from_agent_and_tools(
    agent=agent_core,
    tools=tools2,
    verbose=True,
    # memory=memory,
    handle_parsing_errors=True,
    max_iterations=_max_iterations,
    early_stopping_method="generate",
)

# Variant 3: OpenAI multi-functions agent (uses the underscore tool names).
agent_OPENAI_MULTI = AgentExecutor.from_agent_and_tools(
    agent=agent_core_openai,
    tools=tools,
    verbose=True,
    # memory=memory_openai,
    handle_parsing_errors=True,
    max_iterations=_max_iterations,
    early_stopping_method="generate",
)
# agent.max_execution_time = int(os.getenv("max_iterations"))
# agent.handle_parsing_errors = True
# agent.early_stopping_method = "generate"
# Pick the active agent from the `agent_type` environment variable.
# Previously `agent` stayed undefined when the variable was missing or did not
# match, so the first chat request failed with NameError. Fall back to the
# agent that the UI dropdown shows as its initial value.
Choice = os.getenv("agent_type")
_AGENT_BY_NAME = {
    'Zero Short Agent': agent_ZEROSHOT_AGENT,
    'Zero Short React': agent_ZEROSHOT_REACT,
    'OpenAI Multi': agent_OPENAI_MULTI,
}
if Choice in _AGENT_BY_NAME:
    agent = _AGENT_BY_NAME[Choice]
    print("Set to:", Choice)
else:
    agent = agent_OPENAI_MULTI
    print("Unknown or missing agent_type:", Choice, "- defaulting to OpenAI Multi")
# print(agent.agent.llm_chain.prompt.template)
# print(agent.agent.llm_chain.prompt)

# NOTE: `vectordb` (the former Chroma store) is never assigned in this section,
# so QAQuery() cannot run as-is; only the Pinecone store below is live.
# vectordb = Chroma(persist_directory='db', embedding_function=embeddings)
vectordb_p = Pinecone.from_existing_index(index_name, embeddings)
def chathmi(message, history1):
    """Chat callback that answers straight from the vector database.

    Generator yielding a single value: the answer text from ``QAQuery_p``.
    The incoming history is only printed.
    """
    answer, _sources = QAQuery_p(message)
    time.sleep(0.3)
    print(history1)
    yield answer
def chathmi2(message, history):
    """Chat callback that runs the active agent.

    Generator that yields the agent's reply and, when the reply contains a
    ``(sandbox:/<file>)`` reference, a second value: a one-element list with
    that file name. Errors are printed, not raised.
    """
    global Audio_output
    try:
        response = agent.run(message)
        time.sleep(0.3)
        yield response
        print("response of chatbot:", response)
        print("\n")
        # Look for a reference such as "(sandbox:/sample-20230805-0807.wav)".
        # An explicit check replaces the former bare `except: pass`, which
        # used an IndexError as control flow and hid every other failure.
        _, marker, tail = response.partition("(sandbox:/")
        if marker:
            file_name = tail.split(")")[0]
            print("file_name:", file_name)
            yield [file_name]
        if Audio_output:
            # Forget the files produced while answering this request.
            Audio_output = []
        print("History: ", history)
        print("-" * 20)
        print("-" * 20)
    except Exception as e:
        print("error:", e)
# chatbot = gr.Chatbot().style(color_map =("blue", "pink")) | |
# chatbot = gr.Chatbot(color_map =("blue", "pink")) | |
def func_upload_file(files, chat_history2):
    """Upload callback: show the uploaded files in the chat and ingest them.

    Generator. Stores the uploaded paths in the module-level
    ``file_list_loaded``, yields the history extended by one file entry per
    upload, runs ``UpdateDb`` when ``SYS_Upload_Enable`` is "1", then yields
    the history again with a completion message appended.
    """
    global file_list_loaded
    file_list_loaded = [unit.name for unit in files]
    print(file_list_loaded)

    if files:
        # One chat row per file; a 1-tuple holding the path is the file entry.
        chat_history2 = chat_history2 + [((upload.name,), None) for upload in files]
    yield chat_history2

    upload_enabled = os.getenv("SYS_Upload_Enable") == "1"
    if upload_enabled:
        UpdateDb()

    chat_history2.append(["Request Upload File into DB", "Operation Finished"])
    yield chat_history2
class Logger:
    """Stream that copies everything written to it into a log file.

    Installed as ``sys.stdout`` below so the UI can show console output.
    """

    def __init__(self, filename):
        # Keep the stream that was active at construction time so output
        # still reaches the console.
        self.terminal = sys.stdout
        # Explicit UTF-8: with the platform default encoding, non-ASCII
        # output can raise UnicodeEncodeError on some systems.
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both the console stream and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Report that this stream is not a terminal."""
        return False


# From here on every print() also lands in output.log.
sys.stdout = Logger("output.log")
def read_logs():
    """Return the current contents of ``output.log``.

    ``sys.stdout`` is flushed first so buffered output is included. The file
    is decoded as UTF-8 with undecodable bytes replaced, so a stray byte in
    the log cannot make the viewer fail.
    """
    sys.stdout.flush()
    with open("output.log", "r", encoding="utf-8", errors="replace") as f:
        return f.read()
def SetAgent(Choice):
    """Switch the module-level ``agent`` to the executor named by ``Choice``.

    Unrecognised names leave the current agent unchanged and print nothing.
    """
    global agent
    if Choice == 'Zero Short Agent':
        selected = agent_ZEROSHOT_AGENT
    elif Choice == 'Zero Short React':
        selected = agent_ZEROSHOT_REACT
    elif Choice == 'OpenAI Multi':
        selected = agent_OPENAI_MULTI
    else:
        return
    agent = selected
    print("Set to:", Choice)
# Last chatbot history seen by LinkElement.
record = []


def LinkElement(chatbot_history):
    '''
    Link chatbot display output with other UI.

    Prints the bot side of the newest history entry whenever the history
    differs from the one seen on the previous call, and remembers the new
    history in the module-level ``record``.
    '''
    global record
    if record != chatbot_history:
        if chatbot_history:
            # Entries are (user, bot) pairs; take the bot part of the last one.
            # The former `chatbot_history[-1:][1]` indexed a slice of at most
            # one element and therefore always raised IndexError.
            last_response = chatbot_history[-1][1]
            print("last response:", last_response)
        record = chatbot_history
        print(chatbot_history)
    else:
        print("From linkelement: ", chatbot_history)
def chathmi3(message, history2):
    """Main chat callback: echo the question, run the agent, attach audio.

    Generator yielding ``["", history]`` pairs (the empty string is the first
    output value, the second is the updated chat history):

    1. the history with the user's message appended,
    2. the history with the agent's reply appended,
    3. if the reply contains ``(sandbox:/<file>)``, the history with that
       file appended as a file entry.

    The message is stored in ``last_request`` and a detected file name in
    ``Filename_Chatbot``. Agent errors are printed, not raised.
    """
    global last_request
    global Filename_Chatbot
    print("Input Message:", message)
    last_request = message
    history2 = history2 + [(message, None)]
    yield ["", history2]
    try:
        response = agent.run(message)
        time.sleep(0.1)
        history2 = history2 + [(None, response)]
        yield ["", history2]
        print("response of chatbot:", response)
        # Look for a reference such as "(sandbox:/sample-20230805-0807.wav)".
        # An explicit check replaces the former bare `except`, which also
        # swallowed unrelated errors.
        _, marker, tail = response.partition("(sandbox:/")
        if marker:
            file_name = tail.split(")")[0]
            print("file_name:", file_name)
            history2 = history2 + [(None, (file_name,))]
            Filename_Chatbot = file_name
            yield ["", history2]
        else:
            print("No need to add file in chatbot")
    except Exception as e:
        print("chathmi3 error:", e)
    print("History2: ", history2)
    print("-" * 20)
    print("-" * 20)
def fake(message, history4):
    """Placeholder chat callback; ignores its arguments and returns None."""
    return None
def clearall():
    """Return a fresh empty history, which clears the chatbot display."""
    # Memory clearing is intentionally left disabled:
    # memory_openai.clear()
    # memory.clear()
    return list()
def retry(history3):
    """Re-send the last request to the active agent.

    Generator yielding the updated chat history after each step: the repeated
    question, the agent's reply, and a file entry when the reply contains
    ``(sandbox:/<file>)``. If no request has been made yet the history is
    yielded unchanged. Agent errors are printed, not raised.
    """
    global last_request
    print("last_request", last_request)
    message = last_request
    if not message:
        # Nothing to retry: do not run the agent on an empty prompt.
        yield history3
        return
    history3 = history3 + [(message, None)]
    yield history3
    try:
        response = agent.run(message)
        time.sleep(0.1)
        history3 = history3 + [(None, response)]
        print("response of chatbot:", response)
        yield history3
        # Look for a reference such as "(sandbox:/sample-20230805-0807.wav)".
        _, marker, tail = response.partition("(sandbox:/")
        if marker:
            file_name = tail.split(")")[0]
            print("file_name:", file_name)
            history3 = history3 + [(None, (file_name,))]
            yield history3
        else:
            print("No need to add file in chatbot")
    except Exception as e:
        # The message used to say "chathmi3 error", a copy-paste leftover.
        print("retry error:", e)
def display_input(message, history2):
    """Record ``message`` as the last request and append it to the history.

    Returns a new history list; the list passed in is not modified.
    """
    global last_request
    print("Input Message:", message)
    last_request = message
    pending_row = (message, None)
    return history2 + [pending_row]
def Inference_Agent(history_inf):
    """Run the active agent on ``last_request`` and append its reply.

    Returns:
        ``["", history]``. On failure the error is printed and the history is
        returned unchanged, so the two-output contract still holds (the
        function used to return ``None`` implicitly in that case).
    """
    global last_request
    message = last_request
    try:
        response = agent.run(message)
        time.sleep(0.1)
        history_inf = history_inf + [(None, response)]
    except Exception as e:
        print("error:", e)
    return ["", history_inf]
def ClearText():
    """Return an empty string (used to blank a text input)."""
    return str()
def playsound():
    """Return the last audio file name if it is a WAV file, else ``None``.

    Reads the module-level ``Filename_Chatbot``. The extension is now checked
    on the end of the name; the former ``split(".")[1]`` test failed for names
    containing an extra dot and relied on a bare ``except`` for names with no
    dot at all.
    """
    global Filename_Chatbot
    if isinstance(Filename_Chatbot, str) and Filename_Chatbot.endswith(".wav"):
        return Filename_Chatbot
    return None
def HMI_Runing():
    """Return two updates for the "running" state: first hidden, second visible."""
    hide_first = gr.update(visible=False)
    show_second = gr.update(visible=True)
    return [hide_first, show_second]
def HMI_Wait():
    """Return two updates for the "idle" state: first visible, second hidden."""
    show_first = gr.update(visible=True)
    hide_second = gr.update(visible=False)
    return [show_first, hide_second]
# Gradio UI definition.
# NOTE(review): the original indentation was lost in this copy, so the nesting
# of components inside the rows/column below is reconstructed - confirm against
# the intended layout.
with gr.Blocks() as demo:
    with gr.Column() as main2:
        # A trailing comma after this call used to turn `title` into a 1-tuple.
        title = gr.Markdown(
            "# <center> STLA BABY - YOUR FRIENDLY GUIDE\n"
            "<center> v0.4: Powered by MECH Core Team"
        )
        chatbot = gr.Chatbot()

        # Question input with the SUBMIT/STOP pair (only one is visible at a time).
        with gr.Row():
            inputtext = gr.Textbox(
                scale=4,
                label="",
                placeholder="Input Your Question",
                show_label=False,
            )
            submit_button = gr.Button("SUBMIT", variant="primary", visible=True)
            stop_button = gr.Button("STOP", variant='stop', visible=False)

        # Agent selector, audio widgets and auxiliary buttons.
        with gr.Row():
            agentchoice = gr.Dropdown(
                choices=['Zero Short Agent', 'Zero Short React', 'OpenAI Multi'],
                label="SELECT AI AGENT",
                scale=2,
                show_label=True,
                value="OpenAI Multi",
            )
            voice_input = gr.Audio(
                source="microphone",
                type="filepath",
                scale=1,
                label="INPUT",
            )
            voice_output = gr.Audio(
                source="microphone",
                type="filepath",
                scale=1,
                interactive=False,
                autoplay=True,
                label="OUTPUT",
            )
            upload_button = gr.UploadButton("✡️ INGEST DB", file_count="multiple", scale=0, variant="secondary")
            upload_file_button = gr.UploadButton("📁 UPLOAD", file_count="single", scale=0, variant="secondary")
            retry_button = gr.Button("RETRY")
            clear_button = gr.Button("CLEAR")

        with gr.Accordion(
            label="LOGS",
            open=False,
        ):
            frash_logs = gr.Button("Update Logs ...")
            logs = gr.Textbox(max_lines=25)

    # --- event wiring --------------------------------------------------------
    clear_button.click(clearall, None, chatbot)
    retry_button.click(retry, chatbot, chatbot)

    # Enter key and SUBMIT button run the same chain:
    # switch to STOP -> run the chat -> update the audio output -> switch back.
    inf1 = (
        inputtext.submit(HMI_Runing, None, [submit_button, stop_button])
        .then(chathmi3, [inputtext, chatbot], [inputtext, chatbot])
        .then(playsound, None, voice_output)
        .then(HMI_Wait, None, [submit_button, stop_button])
    )
    inf3 = (
        submit_button.click(HMI_Runing, None, [submit_button, stop_button])
        .then(chathmi3, [inputtext, chatbot], [inputtext, chatbot])
        .then(playsound, None, voice_output)
        .then(HMI_Wait, None, [submit_button, stop_button])
    )

    # STOP cancels both chains, refreshes the log view and restores SUBMIT.
    (
        stop_button.click(read_logs, None, logs, cancels=[inf1, inf3])
        .then(HMI_Wait, None, [submit_button, stop_button])
    )

    upload_button.upload(func_upload_file, [upload_button, chatbot], chatbot)
    agentchoice.change(SetAgent, agentchoice, None)
    frash_logs.click(read_logs, None, logs)
    # demo.load(read_logs, None, logs, every=1)
def CreatDb_P():
    """Rebuild the Pinecone index from the ``.txt`` files under ``./documents``.

    Deletes every vector in the index and uploads fresh embeddings, rebinding
    the module-level ``vectordb_p``. If no documents are found the existing
    index is left untouched; previously it was wiped before anything was
    uploaded.
    """
    global vectordb_p
    index_name = 'stla-baby'
    documents = DirectoryLoader('./documents', glob='**/*.txt').load()
    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
    split_docs = text_splitter.split_documents(documents)
    print(split_docs)
    if not split_docs:
        print("No documents found; Pinecone index left unchanged")
        return
    pinecone.Index(index_name).delete(delete_all=True, namespace='')
    vectordb_p = Pinecone.from_documents(split_docs, embeddings, index_name=index_name)
    print("Pinecone Updated Done")
    print(index.describe_index_stats())
def QAQuery_p(question: str):
    """Answer ``question`` from the Pinecone store with a RetrievalQA chain.

    The number of retrieved chunks comes from the ``search_kwargs_k``
    environment variable and falls back to 4 when it is unset (the old
    ``int(os.getenv(...))`` raised TypeError in that case).

    Returns:
        A ``(response, source)`` tuple: the answer text and the list of
        source documents returned by the chain.
    """
    global vectordb_p
    retriever = vectordb_p.as_retriever()
    # NOTE(review): 4 is believed to be LangChain's usual default k - confirm.
    retriever.search_kwargs['k'] = int(os.getenv("search_kwargs_k", "4"))
    # retriever.search_kwargs['fetch_k'] = 100
    qa = RetrievalQA.from_chain_type(
        llm=chat,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
    res = qa({"query": question})
    response = res['result']
    source = res['source_documents']
    print("-" * 20)
    print("Question:", question)
    print("Answer:", response)
    print("-" * 20)
    print("Source:", source)
    return response, source
# def CreatDb(): | |
# ''' | |
# Funtion to creat chromadb DB based on with all docs | |
# ''' | |
# global vectordb | |
# loader = DirectoryLoader('./documents', glob='**/*.txt') | |
# documents = loader.load() | |
# text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200) | |
# split_docs = text_splitter.split_documents(documents) | |
# print(split_docs) | |
# vectordb = Chroma.from_documents(split_docs, embeddings, persist_directory='db') | |
# vectordb.persist() | |
def QAQuery(question: str):
    """Answer ``question`` from the module-level ``vectordb`` store.

    Retrieves the top 3 chunks and returns only the answer text.

    NOTE(review): nothing visible in this file assigns ``vectordb`` (its
    Chroma setup is commented out) - confirm before calling.
    """
    global vectordb
    doc_retriever = vectordb.as_retriever()
    doc_retriever.search_kwargs['k'] = 3
    chain = RetrievalQA.from_chain_type(
        llm=chat,
        chain_type="stuff",
        retriever=doc_retriever,
        return_source_documents=True,
    )
    outcome = chain({"query": question})
    separator = "-" * 20
    print(separator)
    print("Question:", question)
    print("Answer:", outcome['result'])
    print(separator)
    print("Source:", outcome['source_documents'])
    return outcome['result']
# Used to complete content | |
def completeText(Text):
    """Print ``Text`` followed by the model's completion and a full stop.

    Uses the "Chattester" deployment with temperature 0. Returns None.
    """
    result = openai.Completion.create(
        deployment_id="Chattester",
        prompt=Text,
        temperature=0,
    )
    continuation = result['choices'][0]['text']
    print(f"{Text}{continuation}.")
# Used to chat | |
def chatText(Text):
    """Send ``Text`` as a single user turn and print the assistant's reply.

    Uses the "Chattester" deployment. Returns None.
    """
    deployment_id = "Chattester"
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": Text},
    ]
    response = openai.ChatCompletion.create(
        messages=conversation,
        deployment_id=deployment_id,
    )
    reply = response["choices"][0]["message"]["content"]
    print("\n" + reply + "\n")
if __name__ == '__main__':
    # chatText("what is AI?")
    # CreatDb()
    # QAQuery("what is COFOR ?")
    # CreatDb_P()
    # QAQuery_p("what is GST ?")

    # Serve on all interfaces, port 7860; require a login only when locked.
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860}
    if SysLock == "1":
        launch_kwargs["auth"] = (username, password)
    demo.queue().launch(**launch_kwargs)