|
"""Main entrypoint for the app.""" |
|
import json |
|
import os |
|
from timeit import default_timer as timer |
|
from typing import List, Optional |
|
|
|
from lcserve import serving |
|
from pydantic import BaseModel |
|
|
|
from app_modules.init import app_init |
|
from app_modules.llm_chat_chain import ChatChain |
|
from app_modules.utils import print_llm_response |
|
|
|
# Initialise the shared LLM loader and QA chain once, at import time, so every
# request served by this module reuses them.
llm_loader, qa_chain = app_init()

# History forwarding is enabled only when the variable is exactly "true"
# (case-sensitive; any other value, or an unset variable, disables it).
chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED") == "true"

# Cache of per-session ChatChain objects, keyed by the client-supplied uuid.
# NOTE(review): entries are never evicted here, so this grows with the number
# of distinct uuids seen — confirm whether that is acceptable for deployment.
uuid_to_chat_chain_mapping = {}
|
|
|
|
|
class ChatResponse(BaseModel):
    """Chat response schema.

    Instances are serialized with ``json.dumps(resp.dict())`` by ``chat`` in
    this module, so the field names below are the JSON keys clients receive.
    """

    # Presumably a (streamed) answer token — TODO confirm; ``chat`` in this
    # module never sets it, so it serializes as null there.
    token: Optional[str] = None
    # Error message slot; ``chat`` in this module never sets it (null).
    error: Optional[str] = None
    # Source documents returned by the QA chain; ``chat`` sends an empty list
    # when a session uuid is supplied.
    sourceDocs: Optional[List] = None
|
|
|
|
|
@serving(websocket=True)
def chat(
    question: str,
    history: Optional[List] = None,
    uuid: Optional[str] = None,
    **kwargs,
) -> str:
    """Answer ``question`` and return a JSON-encoded ``ChatResponse``.

    Exposed through lcserve's ``serving`` decorator with ``websocket=True``.

    Two modes are supported:

    * ``uuid is None``: one-shot retrieval QA through the module-level
      ``qa_chain``. When ``CHAT_HISTORY_ENABLED`` is ``"true"``, ``history``
      is converted into ``(str, str)`` pairs and passed as ``chat_history``.
      The answer's ``source_documents`` are returned in ``sourceDocs``.
    * ``uuid`` given: a per-session ``ChatChain`` answers the question. The
      chain is created on first use and cached in
      ``uuid_to_chat_chain_mapping``. ``sourceDocs`` is returned empty.

    Args:
        question: The user's question.
        history: Previous turns. Each element is read as ``element[0]`` and
            ``element[1]``, and falsy entries are replaced by ``""``.
            ``None`` (the default) means "no history". It is ignored unless
            history is enabled and ``uuid`` is ``None``.
        uuid: Optional session identifier that selects the chat-chain mode.
        **kwargs: If present, ``streaming_handler`` is forwarded to the chain.

    Returns:
        ``json.dumps`` of the ``ChatResponse`` as a dict.
    """
    print(f"uuid: {uuid}")

    streaming_handler = kwargs.get("streaming_handler")

    if uuid is None:
        chat_history = []
        # ``history`` previously defaulted to a shared mutable ``[]``; ``None``
        # is now the default and is treated as an empty history.
        if chat_history_enabled and history is not None:
            chat_history = [
                (element[0] or "", element[1] or "") for element in history
            ]

        start = timer()
        result = qa_chain.call_chain(
            {"question": question, "chat_history": chat_history}, streaming_handler
        )
        end = timer()
        print(f"Completed in {end - start:.3f}s")

        resp = ChatResponse(sourceDocs=result["source_documents"])
        return json.dumps(resp.dict())

    # Session mode: reuse (or lazily create) the ChatChain for this uuid.
    # Named ``chat_chain`` to avoid shadowing this function's own name.
    if uuid in uuid_to_chat_chain_mapping:
        chat_chain = uuid_to_chat_chain_mapping[uuid]
    else:
        chat_chain = ChatChain(llm_loader)
        uuid_to_chat_chain_mapping[uuid] = chat_chain

    result = chat_chain.call_chain({"question": question}, streaming_handler)
    print(f"result: {result}")

    resp = ChatResponse(sourceDocs=[])
    return json.dumps(resp.dict())
|
|
|
|
|
if __name__ == "__main__":
    # Local smoke test: ask a single question with an empty history, then
    # decode the JSON reply and hand it to the pretty-printer.
    raw_reply = chat("What's deep learning?", [])
    print_llm_response(json.loads(raw_reply))
|
|