# learn-ai / app.py
# Last commit by dh-mc:
# "updated default questions"
# (3ac9f6d)
"""Main entrypoint for the app."""
import os
from threading import Thread
import time
from queue import Queue
from timeit import default_timer as timer
import gradio as gr
from anyio.from_thread import start_blocking_portal
from app_modules.init import app_init
from app_modules.llm_chat_chain import ChatChain
from app_modules.utils import print_llm_response, remove_extra_spaces
# Load the model and build the default QA chain once, at import time.
llm_loader, qa_chain = app_init()

# Runtime switches read from environment variables. Comparisons are exact
# and case-sensitive (e.g. only the string "true" enables a boolean flag).
share_gradio_app = os.getenv("SHARE_GRADIO_APP") == "true"
using_openai = os.getenv("LLM_MODEL_TYPE") == "openai"

# The Orca-2 prompt template is never used together with OpenAI models.
chat_with_orca_2 = (not using_openai) and (
    os.getenv("USE_ORCA_2_PROMPT_TEMPLATE") == "true"
)

# Chat history is always disabled in Orca-2 mode.
chat_history_enabled = (not chat_with_orca_2) and (
    os.getenv("CHAT_HISTORY_ENABLED") == "true"
)
# Display name of the model in use, plus a link to its documentation page
# (OpenAI) or its Hugging Face model page (everything else).
if using_openai:
    model = "OpenAI GPT-3.5"
    href = "https://platform.openai.com/docs/models/gpt-3-5"
else:
    model = os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH")
    href = f"https://huggingface.co/{model}"
# Orca-2 mode replaces the QA chain returned by app_init() with a plain
# ChatChain and offers general-purpose sample prompts; otherwise the default
# chain is kept and the sample prompts are about the "AI Books" topics.
if chat_with_orca_2:
    qa_chain = ChatChain(llm_loader)
    name = "Orca-2"
    examples = ["How to cook a fish?", "Who is the president of US now?"]
else:
    name = "AI Books"
    examples = [
        "What's Machine Learning?",
        "What's Generative AI?",
        "What's Difference in Differences?",
        "What's Instrumental Variable?",
    ]

title = f"Chat with {name}"

# HTML blurb rendered under the title, linking to the model in use.
description = f"""\
<div align="left">
<p> Currently Running: <a href="{href}">{model}</a></p>
</div>
"""
def task(question, chat_history, q, result):
    """Run a single QA-chain call; used as the target of a worker Thread.

    Args:
        question: the user's question.
        chat_history: list of (user, bot) string pairs, possibly empty.
        q: queue passed through to ``qa_chain.call_chain``; ``predict``
            waits for it to become non-empty before reading the streamer.
        result: queue that receives the chain's return value.
    """
    started = timer()
    ret = qa_chain.call_chain(
        {"question": question, "chat_history": chat_history}, None, q
    )
    elapsed = timer() - started
    print(f"Completed in {elapsed:.3f}s")
    print_llm_response(ret)
    result.put(ret)
def predict(message, history):
    """Stream a reply to ``message`` for ``gr.ChatInterface``.

    Starts ``task`` on a background thread and yields the growing reply as
    tokens arrive from ``llm_loader.streamer``. Unless Orca-2 chat mode is
    on, a Markdown list of the source documents is appended at the end.

    Args:
        message: the user's new question.
        history: previous turns as (user, bot) pairs supplied by Gradio;
            only forwarded to the chain when ``chat_history_enabled`` is set.

    Yields:
        The accumulated reply text so far.

    Raises:
        RuntimeError: if the worker thread exits before signalling that
            generation started, or exits without publishing a result
            (i.e. the chain raised). Previously both cases hung forever.
    """
    print("predict:", message, history)

    chat_history = []
    if chat_history_enabled:
        # Either side of a turn may be None/empty; normalize it to "".
        chat_history = [(turn[0] or "", turn[1] or "") for turn in history]

    if not chat_history:
        # New conversation: reset the chain (reset() is defined elsewhere;
        # presumably it clears per-conversation state).
        qa_chain.reset()

    q = Queue()
    result = Queue()
    t = Thread(target=task, args=(message, chat_history, q, result))
    t.start()  # Starting the generation in a separate thread.

    partial_message = ""
    # With history the streamer is drained twice, separated by "\n\n" below.
    # NOTE(review): presumably once for the condensed standalone question and
    # once for the answer -- confirm against call_chain.
    count = 2 if chat_history else 1

    while count > 0:
        # call_chain presumably puts something into q once output starts.
        while q.empty():
            # A dead worker that never signalled means the chain failed;
            # without this check we would poll forever. q is re-checked
            # after is_alive() so a worker that signalled and then exited
            # between the two tests is not misreported as a failure.
            if not t.is_alive() and q.empty():
                raise RuntimeError(
                    "LLM worker thread exited before generation started; "
                    "see the server log for its traceback"
                )
            print("nothing generated yet - retry in 0.5s")
            time.sleep(0.5)

        # NOTE(review): if generation dies mid-stream, this loop depends on
        # the streamer's own end-of-stream/timeout handling to terminate.
        for next_token in llm_loader.streamer:
            partial_message += next_token or ""
            yield partial_message

        if count == 2:
            partial_message += "\n\n"
        count -= 1

    if not chat_with_orca_2:
        partial_message += "\n\nSources:\n"
        # Wait for the worker to finish so a chain failure is reported
        # instead of blocking forever on result.get().
        t.join()
        if result.empty():
            raise RuntimeError(
                "LLM worker thread exited without producing a result; "
                "see the server log for its traceback"
            )
        ret = result.get()

        seen_titles = set()
        for doc in ret["source_documents"]:
            # assumes metadata["page"] is a 0-based int -- TODO confirm
            page = doc.metadata["page"] + 1
            url = f"{doc.metadata['url']}#page={page}"
            file_name = doc.metadata["source"].split("/")[-1]
            source_title = f"{file_name} Page: {page}"
            if source_title not in seen_titles:
                seen_titles.add(source_title)
                # Every item uses "1."; Markdown numbers the rendered list.
                partial_message += f"1. [{source_title}]({url})\n"

        yield partial_message
# Build the Gradio chat UI around the streaming `predict` generator, then
# serve it (publicly shared only when SHARE_GRADIO_APP is "true").
demo = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    examples=examples,
)
demo.launch(share=share_gradio_app)