|
import gradio as gr |
|
from openai import OpenAI |
|
from qdrant_client import QdrantClient |
|
|
|
import os |
|
|
|
# Qdrant vector store holding the text chunks searched by question_answer.
# The cluster URL can be overridden with the QDRANT_URL environment variable;
# the default keeps the original cluster so existing deployments are unaffected.
qclient = QdrantClient(
    url=os.getenv(
        "QDRANT_URL",
        "https://68106439-3d00-42df-880f-a5519695f677.us-east4-0.gcp.cloud.qdrant.io:6333",
    ),
    api_key=os.getenv("QDRANT_API_KEY"),
)

# OpenRouter speaks the OpenAI-compatible API, so the openai SDK is pointed at it.
# NOTE: if OPENROUTER_API_KEY is unset, api_key is None and the openai SDK falls
# back to the OPENAI_API_KEY environment variable, which would then be sent to
# openrouter.ai. Make sure OPENROUTER_API_KEY is set in the deployment.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.getenv("OPENROUTER_API_KEY"),
)
|
|
|
def chat(prompt: str, model: str = "anthropic/claude-3-haiku") -> str:
    """Send a single-turn prompt to the LLM through OpenRouter and return the reply.

    Args:
        prompt: The complete user message to send.
        model: OpenRouter model identifier; defaults to Claude 3 Haiku.

    Returns:
        The text of the first completion choice, or an empty string when the
        API returns no text content.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    # ``message.content`` is Optional in the OpenAI SDK; normalise None to ""
    # so the declared ``-> str`` contract actually holds for callers.
    return completion.choices[0].message.content or ""
|
|
|
def _embed(question):
    """Return the embedding of *question* from the hosted BAAI/bge-large-zh-v1.5 model.

    Raises:
        requests.HTTPError: if the Hugging Face Inference API returns an error status.
        ValueError: if the response body is not a non-empty flat list of numbers.
    """
    import requests

    api_url = "https://api-inference.huggingface.co/models/BAAI/bge-large-zh-v1.5"
    headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"}
    # Without a timeout a stalled request would block this Gradio worker forever.
    response = requests.post(
        api_url, headers=headers, json={"inputs": question}, timeout=60
    )
    # Surface HTTP errors here instead of handing an error payload
    # (e.g. {"error": ...}) to Qdrant as if it were a query vector.
    response.raise_for_status()
    vector = response.json()
    if not (
        isinstance(vector, list)
        and vector
        and all(isinstance(x, (int, float)) for x in vector)
    ):
        raise ValueError(f"Unexpected embedding response: {str(vector)[:200]}")
    return vector


def _format_context(hits):
    """Join the ``text`` payload field of each search hit, one hit per line.

    Hits with no payload or no ``text`` key are skipped rather than aborting
    the whole answer.
    """
    return "\n".join(
        hit.payload["text"]
        for hit in hits
        if hit.payload and "text" in hit.payload
    )


def _build_prompt(context, question):
    """Build the Chinese RAG prompt.

    The prompt tells the model it has been given context retrieved by vector
    search, asks it to answer *question* briefly and with numbers, and to reply
    "I don't know" when the context holds no answer.
    """
    return (
        f"现在你是一个资深的工程师管家,我将相关的信息已经从数据库中通过向量搜索给你了,如下\n{context}\n, 根据这些信息回答我的这个问题\n{question}\n,"
        "尽量简短以及用数值去说明,如果并没有答案,请回答我不知道。"
    )


def question_answer(chat_history, question):
    """Answer *question* with retrieval-augmented generation and record it in the chat.

    The pipeline embeds the question, fetches the 20 nearest chunks from the Qdrant
    collection ``test_collection``, and asks the LLM (via ``chat``) to answer
    from that context.

    Args:
        chat_history: Gradio Chatbot value, a list of ``[question, answer]`` pairs.
            ``None`` is treated as an empty history.
        question: The user's question. Blank input is ignored.

    Returns:
        The chat history with the new ``[question, answer]`` pair appended. The
        list passed in is mutated in place and returned.
    """
    if chat_history is None:
        chat_history = []
    if not question or not question.strip():
        # Nothing to ask; skip the three remote calls (embedding, search, LLM).
        return chat_history

    vector = _embed(question)
    hits = qclient.search(
        collection_name="test_collection", query_vector=vector, limit=20
    )
    context = _format_context(hits)
    print(context)  # debug aid: shows the retrieved context in the server log

    answer = chat(_build_prompt(context, question))
    chat_history.append([question, answer])
    return chat_history
|
|
|
# Gradio UI: a question box with a submit button beside the chat transcript.
# The CSS targets the Chatbot through its elem_id="chatbot". The height needs
# a unit; a bare "1200" is invalid CSS and is silently dropped by browsers.
with gr.Blocks(css="""#chatbot { font-size: 14px; min-height: 1200px; }""") as demo:
    gr.Markdown('<center><h3>Demo</h3></center>')
    with gr.Row():
        with gr.Group():
            question = gr.Textbox(label='Enter your question here')
            btn = gr.Button(value='Submit')
        with gr.Group():
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot")

    # The Chatbot is both an input (the current history) and the output (the
    # updated history). api_name also exposes this handler as the "predict"
    # API endpoint.
    btn.click(
        question_answer,
        inputs=[chatbot, question],
        outputs=[chatbot],
        api_name="predict",
    )

# share=True additionally requests a temporary public Gradio share link.
demo.launch(share=True)