"""Main entrypoint for the app."""

import os
import time
from queue import Queue
from threading import Thread
from timeit import default_timer as timer

import gradio as gr

from app_modules.init import app_init
from app_modules.llm_chat_chain import ChatChain
from app_modules.utils import print_llm_response

llm_loader, qa_chain = app_init()
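
# All runtime behavior is driven by environment variables. A typical local
# setup might look like this (illustrative values, not taken from this repo):
#   LLM_MODEL_TYPE=huggingface
#   HUGGINGFACE_MODEL_NAME_OR_PATH=microsoft/Orca-2-7b
#   USE_ORCA_2_PROMPT_TEMPLATE=true
#   CHAT_HISTORY_ENABLED=true
#   SHARE_GRADIO_APP=false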
share_gradio_app = os.environ.get("SHARE_GRADIO_APP") == "true"
using_openai = os.environ.get("LLM_MODEL_TYPE") == "openai"
chat_with_orca_2 = (
    not using_openai and os.environ.get("USE_ORCA_2_PROMPT_TEMPLATE") == "true"
)
chat_history_enabled = (
    not chat_with_orca_2 and os.environ.get("CHAT_HISTORY_ENABLED") == "true"
)
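
# Display name and docs/model-card link for the model currently running.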
model = (
    "OpenAI GPT-3.5"
    if using_openai
    else os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH")
)
href = (
    "https://platform.openai.com/docs/models/gpt-3-5"
    if using_openai
    else f"https://huggingface.co/{model}"
)
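
# Orca-2 mode bypasses retrieval: replace the QA chain from app_init() with a
# plain chat chain over the same loaded LLM.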
if chat_with_orca_2:
    qa_chain = ChatChain(llm_loader)
    name = "Orca-2"
else:
    name = "AI Books"

title = f"Chat with {name}"
examples = (
    ["How to cook a fish?", "Who is the president of the US now?"]
    if chat_with_orca_2
    else [
        "What's Machine Learning?",
        "What's Generative AI?",
        "What's Difference in Differences?",
        "What's Instrumental Variable?",
    ]
)
description = f"""\
<div align="left">
<p>Currently Running: <a href="{href}">{model}</a></p>
</div>
"""
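

# task() runs the chain on a background thread so predict() can stream tokens
# from llm_loader.streamer while generation is in progress; the complete chain
# output is handed back through the `result` queue.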
def task(question, chat_history, q, result):
    start = timer()
    inputs = {"question": question, "chat_history": chat_history}
    ret = qa_chain.call_chain(inputs, None, q)
    end = timer()

    print(f"Completed in {end - start:.3f}s")
    print_llm_response(ret)

    result.put(ret)
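

# predict() has the (message, history) generator signature gr.ChatInterface
# expects; it yields progressively longer strings so the UI animates streaming.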
def predict(message, history):
    print("predict:", message, history)

    # Rebuild (user, assistant) tuples for the chain from Gradio's history.
    chat_history = []
    if chat_history_enabled:
        for element in history:
            item = (element[0] or "", element[1] or "")
            chat_history.append(item)

    # No history means a fresh conversation: reset any state in the chain.
    if not chat_history:
        qa_chain.reset()

    # `q` is polled below to detect when generation has begun; `result`
    # carries the final chain output.
    q = Queue()
    result = Queue()
    t = Thread(target=task, args=(message, chat_history, q, result))
    t.start()

    partial_message = ""
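
    # With chat history present, the chain streams two passes (presumably the
    # condensed standalone question first, then the final answer), so the
    # token streamer must be drained twice.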
    count = 2 if len(chat_history) > 0 else 1

    while count > 0:
        while q.empty():
            print("nothing generated yet - retry in 0.5s")
            time.sleep(0.5)

        for next_token in llm_loader.streamer:
            partial_message += next_token or ""
            yield partial_message

        if count == 2:
            partial_message += "\n\n"

        count -= 1
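
    # Retrieval mode: block for the chain's full result, then append
    # deduplicated source links ("1." on every item lets Markdown renumber
    # the ordered list).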
    if not chat_with_orca_2:
        partial_message += "\n\nSources:\n"
        ret = result.get()
        titles = []
        for doc in ret["source_documents"]:
            page = doc.metadata["page"] + 1
            url = f"{doc.metadata['url']}#page={page}"
            file_name = doc.metadata["source"].split("/")[-1]
            title = f"{file_name} Page: {page}"
            if title not in titles:
                titles.append(title)
                partial_message += f"1. [{title}]({url})\n"

        yield partial_message
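

# Wire predict() into the Gradio chat UI; SHARE_GRADIO_APP=true exposes a
# public share link.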
gr.ChatInterface(
    predict,
    title=title,
    description=description,
    examples=examples,
).launch(share=share_gradio_app)