# RAG-test2 / main.py
# FastAPI service: retrieval-augmented question answering over the PDFs in
# `data/`, using Gemini (chat + embeddings), a Chroma retriever, and
# Google Translate for Chinese <-> English round-tripping.
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain.chains import ConversationalRetrievalChain
import pdfplumber
import os
import google.generativeai as genai
from deep_translator import GoogleTranslator
# Initialise the FastAPI application.
app = FastAPI()
print('程式初始化')
# Configure the Google Generative AI SDK; fail fast when the key is absent.
if not (api_key := os.getenv("GOOGLE_API_KEY")):
    raise ValueError("GOOGLE_API_KEY is not set")
genai.configure(api_key=api_key)
# Model selection: chat LLM and embedding-model identifiers.
llm_model = 'gemini-1.5-flash'
embeddings_model = "models/embedding-001"
pdf_dir = 'data'
# Read every PDF under `data/` and concatenate the extracted page text.
print('-' * 21, '讀取資料', '-' * 21)
page_texts = []
# sorted() makes the ingest order deterministic across filesystems.
for filename in sorted(os.listdir(pdf_dir)):
    if filename.endswith('.pdf'):
        print(filename)
        with pdfplumber.open(os.path.join(pdf_dir, filename)) as pdf:
            for page in pdf.pages:
                # extract_text() returns None for image-only/empty pages;
                # guard it so concatenation cannot raise TypeError.
                page_texts.append(page.extract_text() or "")
# join() builds the corpus in one pass instead of quadratic `+=`.
docs = "".join(page_texts)
print('-' * 21, '讀取完成', '-' * 21)
# Abort start-up when no PDF text was collected.
if not docs:
    raise ValueError("No documents found in the 'data' directory.")
# Split the corpus into overlapping chunks for retrieval.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
texts = text_splitter.split_text(docs)
# Build the embedding model and a top-1 Chroma retriever over the chunks.
embeddings = GoogleGenerativeAIEmbeddings(
    model=embeddings_model, google_api_key=api_key
)
retriever = Chroma.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 1})
print('分割文本完成')
# Initialise the Gemini chat model used to answer questions.
llm = ChatGoogleGenerativeAI(
    model=llm_model, temperature=0.1, google_api_key=api_key
)
print('模型載入完成')
# Translation helpers (deep-translator wraps the Google Translate service).
def translate_to_english(text):
    """Translate *text* into English, auto-detecting the source language."""
    translator = GoogleTranslator(source='auto', target='en')
    return translator.translate(text)
def translate_to_chinese(text):
    """Translate *text* into Traditional Chinese (zh-TW)."""
    translator = GoogleTranslator(source='auto', target='zh-TW')
    return translator.translate(text)
# Conversation memory shared by all requests.
# NOTE(review): this is a single global conversation — not per-client and not
# protected by a lock; confirm whether per-session history is wanted.
chat_history = []

def invoke(question: str) -> str:
    """Answer *question* through the RAG chain, translating to/from English.

    The question is translated to English, answered by a
    ConversationalRetrievalChain built from the module-level ``llm`` and
    ``retriever``, and the answer is translated back to Traditional Chinese.

    Returns the translated answer, or a fixed error string if anything fails.
    """
    print('invoke 函數觸發')
    try:
        # Build the QA chain (constructed per call; llm/retriever are global).
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm, retriever=retriever
        )
        english_question = translate_to_english(question)
        response = qa_chain.invoke(
            {"question": english_question, "chat_history": chat_history}
        )
        # Fix: record the exchange so follow-up questions actually get
        # conversational context — previously chat_history was never updated.
        chat_history.append((english_question, response['answer']))
        return translate_to_chinese(response['answer'])
    except Exception as e:
        # Boundary handler: log and return a safe message instead of a 500.
        print(f"Error during invoke: {e}")
        return "An error occurred while processing the request."
# Request payload schema for the /invoke endpoint.
class InvokeRequest(BaseModel):
    # The user's question; invoke() translates it to English before retrieval.
    question: str
# API-key authentication configuration.
API_KEY_NAME = "X-API-Key"
# NOTE(review): falls back to a hard-coded placeholder secret when the
# API_KEY env var is unset — must be overridden in any real deployment.
API_KEY = os.getenv("API_KEY", "your_secret_api_key")
api_key_header = APIKeyHeader(name=API_KEY_NAME)
# FastAPI dependency that validates the X-API-Key header.
def get_api_key(api_key: str = Depends(api_key_header)):
    """Return the presented API key when it matches, else raise HTTP 403."""
    if api_key == API_KEY:
        return api_key
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid API Key",
    )
# Root endpoint, used as a simple liveness check.
@app.get("/")
async def read_root():
    """Return a static greeting confirming the service is up."""
    return {"message": "Hello, World!"}
# POST endpoint: accepts a question and returns the model's answer.
@app.post("/invoke")
async def rag_endpoint(request: InvokeRequest, api_key: str = Depends(get_api_key)):
    """Run the RAG pipeline for the submitted question (valid API key required)."""
    answer = invoke(request.question)
    return {"result": answer}
# Entry point: `python main.py` for local dev, or `uvicorn main:app --reload`.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)