# Hugging Face Spaces page header captured with this source (Space status: Build error).
from langchain import PromptTemplate | |
from langchain_community.llms import LlamaCpp | |
from langchain.chains import RetrievalQA | |
from langchain_community.embeddings import SentenceTransformerEmbeddings | |
from fastapi import FastAPI, Request, Form, Response | |
from fastapi.responses import HTMLResponse | |
from fastapi.templating import Jinja2Templates | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.encoders import jsonable_encoder | |
from qdrant_client import QdrantClient | |
from langchain_community.vectorstores import Qdrant | |
import os | |
import json | |
from huggingface_hub import hf_hub_download | |
from langchain.retrievers import EnsembleRetriever | |
from ingest import keyword_retriever | |
# Web application object, the Jinja2 template loader used to render HTML
# pages, and the mount that serves files from ./static under /static.
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount(
    path="/static",
    app=StaticFiles(directory="static"),
    name="static",
)
# Hugging Face Hub repository and file name of the GGUF model used by the app.
model_name = "aaditya/OpenBioLLM-Llama3-8B-GGUF"
model_file = "openbiollm-llama3-8b.Q5_K_M.gguf"

# Download the model file into the working directory and keep the path that
# the hub client reports for it.
model_path = hf_hub_download(model_name, filename=model_file, local_dir="./")

# Load from the path the download returned rather than from a second,
# hand-typed copy of the file name, so the two can never drift apart.
# `local_llm` stays defined as a module-level name for any code that reads it.
local_llm = model_path

# Make sure the model path is correct for your system!
llm = LlamaCpp(
    model_path=local_llm,
    temperature=0.3,
    n_ctx=2048,
    top_p=1,
)
print("LLM Initialized....")
# Prompt text with two placeholders, {context} and {question}, wrapped in a
# PromptTemplate that declares exactly those two input variables.
prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer. Answer must be detailed and well explained.
Helpful answer:
"""
prompt = PromptTemplate(
    input_variables=['context', 'question'],
    template=prompt_template,
)

# Vector side: ClinicalBERT embeddings over the "vector_db" collection of the
# Qdrant server at `url`, returning a single hit per query (k=1).
embeddings = SentenceTransformerEmbeddings(model_name="medicalai/ClinicalBERT")
url = "http://localhost:6333"
client = QdrantClient(url=url, prefer_grpc=False)
db = Qdrant(
    collection_name="vector_db",
    client=client,
    embeddings=embeddings,
)
retriever = db.as_retriever(search_kwargs={"k": 1})

# Combine the vector retriever with the keyword retriever imported from
# `ingest`, weighting the two equally.
ensemble_retriever = EnsembleRetriever(
    retrievers=[retriever, keyword_retriever],
    weights=[0.5, 0.5],
)
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the landing page.

    Registered on ``app`` so the handler is actually reachable; without a
    route decorator the function is defined but never served.
    NOTE(review): the path "/" is inferred from the handler name - confirm.

    Args:
        request: Incoming request, passed to the template context as
            ``"request"``.

    Returns:
        The rendered ``index.html`` template response.
    """
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/get_response")
async def get_response(query: str = Form(...)):
    """Answer a submitted question with the retrieval-augmented QA chain.

    Registered on ``app`` as a POST route because the handler reads a form
    field; without a route decorator it is never served.
    NOTE(review): the path "/get_response" is assumed from the handler name -
    confirm it against the form action / fetch URL used by the front end.

    Args:
        query: The user's question, taken from the required ``query`` form
            field.

    Returns:
        A ``Response`` whose body is a JSON object with the keys ``answer``
        (the chain's result), ``source_document`` (text of the first
        retrieved document) and ``doc`` (that document's ``source``
        metadata). The last two are empty strings when no document was
        retrieved or the metadata has no ``source`` entry.
    """
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=ensemble_retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
        verbose=True,
    )
    response = qa(query)
    print(response)

    # Retrieval may return nothing; indexing [0] unconditionally would turn
    # that into an unhandled IndexError instead of an answer without a source.
    source_documents = response.get("source_documents") or []
    if source_documents:
        top_document = source_documents[0]
        source_document = top_document.page_content
        # Not every retrieved document is guaranteed to carry a "source" key.
        doc = top_document.metadata.get("source", "")
    else:
        source_document = ""
        doc = ""

    payload = json.dumps(
        {
            "answer": response["result"],
            "source_document": source_document,
            "doc": doc,
        }
    )
    # The body is already a JSON string, so it is sent as-is and labelled with
    # the JSON media type so clients can parse it as such.
    return Response(content=payload, media_type="application/json")