# Spaces: Sleeping
# routes.py
import os
import shutil
import tempfile

from fastapi import APIRouter, HTTPException, UploadFile, File, Form

from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
from trainer_manager import get_trainer
from config import CUSTOM_PROMPT

# Router object for this module's endpoints.
# NOTE(review): no handler in this file is visibly registered on `router`
# (no decorators are present in this view) — confirm where the routes are
# attached.
router = APIRouter()
def initialize_bot():
    """Allocate a new bot id and attach the configured custom prompt to it.

    Returns:
        An ``InitializeBotResponse`` carrying the newly created ``bot_id``.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            any step after obtaining the trainer fails.
    """
    bot_trainer = get_trainer()
    try:
        new_bot_id = bot_trainer.initialize_bot_id()
        bot_trainer.set_custom_prompt_template(new_bot_id, CUSTOM_PROMPT)
        # Built inside the try so a response-construction failure is also
        # reported as a 500.
        payload = InitializeBotResponse(bot_id=new_bot_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return payload
def add_document(document: DocumentPath):
    """Add the document at ``document.data_path`` to the bot ``document.bot_id``.

    Args:
        document: Request body providing ``data_path`` and ``bot_id``.

    Returns:
        A dict with a single ``"message"`` key confirming the addition.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            the trainer call fails.
    """
    doc_trainer = get_trainer()
    try:
        doc_trainer.add_document_from_path(document.data_path, document.bot_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"message": "Document added successfully."}
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
    """Store an uploaded file temporarily and add it to the given bot.

    The upload is written to a named temporary file (keeping the original
    file extension, if any), handed to the trainer by path, and the temporary
    file is removed afterwards whether or not the trainer call succeeded.

    Args:
        bot_id: Identifier of the bot that should receive the document.
        file: The uploaded file.

    Returns:
        A dict with a single ``"message"`` key confirming the upload.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            reading, writing, adding, or removing the file fails.
    """
    trainer = get_trainer()
    try:
        # An upload may arrive without a filename; fall back to "no suffix"
        # instead of letting os.path.splitext(None) raise TypeError.
        suffix = os.path.splitext(file.filename or "")[1]
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            # Close the handle before the trainer opens the path so the
            # contents are flushed to disk.
            with tmp:
                tmp.write(await file.read())
            trainer.add_document_from_path(tmp.name, bot_id)
        finally:
            # delete=False means nothing else cleans this up; remove it on
            # both the success and the failure path to avoid leaking files.
            os.remove(tmp.name)
        return {"message": "Document uploaded and added successfully."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def create_bot(bot_id: str):
    """Ask the trainer to create a bot with the supplied identifier.

    Args:
        bot_id: Identifier passed to ``trainer.create_bot``.

    Returns:
        A dict with a single ``"message"`` key naming the created bot.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            the trainer call fails.
    """
    bot_trainer = get_trainer()
    try:
        bot_trainer.create_bot(bot_id)
        confirmation = f"Bot {bot_id} created successfully."
        return {"message": confirmation}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
def new_chat(bot_id: str):
    """Open a new chat for the given bot.

    Args:
        bot_id: Identifier passed to ``trainer.new_chat``.

    Returns:
        A ``NewChatResponse`` carrying the ``chat_id`` the trainer returned.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            creating the chat or building the response fails.
    """
    chat_trainer = get_trainer()
    try:
        return NewChatResponse(chat_id=chat_trainer.new_chat(bot_id))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
def send_query(query_request: QueryRequest):
    """Forward a query to the trainer and wrap its answer.

    Args:
        query_request: Request body providing ``query``, ``bot_id`` and
            ``chat_id``.

    Returns:
        A ``QueryResponse`` built from the two values returned by
        ``trainer.get_response`` (``response`` and ``web_sources``).

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            the trainer call, the unpacking, or building the response fails.
    """
    query_trainer = get_trainer()
    try:
        outcome = query_trainer.get_response(
            query_request.query,
            query_request.bot_id,
            query_request.chat_id,
        )
        # The trainer's result is unpacked as exactly two values.
        answer, sources = outcome
        return QueryResponse(response=answer, web_sources=sources)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
def list_chats(bot_id: str):
    """Return whatever ``trainer.list_chats`` yields for the given bot.

    Args:
        bot_id: Identifier passed to ``trainer.list_chats``.

    Returns:
        The trainer's result, passed through unchanged.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            the trainer call fails.
    """
    chat_trainer = get_trainer()
    try:
        return chat_trainer.list_chats(bot_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
def chat_history(chat_id: str, bot_id: str):
    """Fetch a chat by its identifier.

    Args:
        chat_id: Identifier passed to ``trainer.get_chat_by_id``.
        bot_id: Accepted as part of the signature but not used by this
            function; the lookup is done by ``chat_id`` alone.

    Returns:
        The trainer's result, passed through unchanged.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            the trainer call fails.
    """
    history_trainer = get_trainer()
    try:
        return history_trainer.get_chat_by_id(chat_id=chat_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))