"""FastAPI service that answers questions about local PDF files using Gemini.

At import time the module reads every PDF found in ``pdf_dir``, splits the
extracted text into chunks, indexes the chunks in a Chroma vector store and
creates the Gemini chat model. The ``/invoke`` endpoint then translates the
incoming question to English, runs it through a retrieval chain and returns
the answer translated to Traditional Chinese.

Run with ``uvicorn main:app --reload``.
"""

import functools
import os
import secrets

import google.generativeai as genai
import pdfplumber
from deep_translator import GoogleTranslator
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import APIKeyHeader
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from pydantic import BaseModel

# Initialise the FastAPI application.
app = FastAPI()
print('程式初始化')

# Configure the Google API key; the service cannot start without it.
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
    raise ValueError("GOOGLE_API_KEY is not set")
genai.configure(api_key=api_key)

# Model selection and the directory that holds the source PDFs.
llm_model = 'gemini-1.5-flash'
embeddings_model = "models/embedding-001"
pdf_dir = 'data'

# Read the PDF files and collect their text.
print('-' * 21, '讀取資料', '-' * 21)
page_texts = []
# Sorted so the concatenated text (and therefore the chunking) is the same on
# every start, regardless of the order the filesystem lists the files in.
for filename in sorted(os.listdir(pdf_dir)):
    if not filename.lower().endswith('.pdf'):
        continue
    print(filename)
    with pdfplumber.open(os.path.join(pdf_dir, filename)) as pdf:
        for page in pdf.pages:
            page_text = page.extract_text()
            # Skip pages that yield no text (None or empty) instead of
            # failing on string concatenation.
            if page_text:
                page_texts.append(page_text)
# Join with a newline so the last word of one page does not fuse with the
# first word of the next.
docs = "\n".join(page_texts)
print('-' * 21, '讀取完成', '-' * 21)

# Split the text into chunks and build the retrieval pipeline.
if docs:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_text(docs)

    # Build the embedding model and the retriever (one chunk per query).
    embeddings = GoogleGenerativeAIEmbeddings(
        model=embeddings_model,
        google_api_key=api_key,
    )
    retriever = Chroma.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 1})
    print('分割文本完成')

    # Initialise the Gemini chat model.
    llm = ChatGoogleGenerativeAI(
        model=llm_model,
        temperature=0.1,
        google_api_key=api_key,
    )
    print('模型載入完成')
else:
    raise ValueError("No documents found in the 'data' directory.")


# Translation helpers.
def translate_to_english(text):
    """Return ``text`` translated to English (source language auto-detected)."""
    return GoogleTranslator(source='auto', target='en').translate(text)


def translate_to_chinese(text):
    """Return ``text`` translated to Traditional Chinese (``zh-TW``)."""
    return GoogleTranslator(source='auto', target='zh-TW').translate(text)


# Chat history passed to the chain on every call.
# NOTE(review): nothing in this module appends to this list, so the chain
# always receives an empty history. It is deliberately left that way here:
# filling a single global list would mix the conversations of all clients.
chat_history = []


@functools.lru_cache(maxsize=1)
def _get_qa_chain():
    """Build the retrieval chain on first use and reuse it afterwards."""
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
    )


def invoke(question: str):
    """Answer ``question`` from the indexed PDFs.

    The question is translated to English, passed to the retrieval chain
    together with ``chat_history``, and the chain's answer is translated to
    Traditional Chinese.

    Args:
        question: The user's question in any language.

    Returns:
        The translated answer, or a fixed English error message if any step
        raises.
    """
    print('invoke 函數觸發')
    try:
        # Fetched inside the try block so a construction failure is reported
        # the same way as any other processing error.
        qa_chain = _get_qa_chain()
        english_question = translate_to_english(question)
        response = qa_chain.invoke(
            {"question": english_question, "chat_history": chat_history}
        )
        return translate_to_chinese(response['answer'])
    except Exception as e:
        # Best-effort boundary: log the failure and return a generic message
        # so the endpoint still responds.
        print(f"Error during invoke: {e}")
        return "An error occurred while processing the request."


# Request body model.
class InvokeRequest(BaseModel):
    """Request body for ``POST /invoke``."""

    # The question to answer.
    question: str


# API key settings.
API_KEY_NAME = "X-API-Key"
_PLACEHOLDER_API_KEY = "your_secret_api_key"
API_KEY = os.getenv("API_KEY", _PLACEHOLDER_API_KEY)
if API_KEY == _PLACEHOLDER_API_KEY:
    # SECURITY: the fallback value is publicly known. It is kept so existing
    # deployments continue to start, but it must be overridden in production.
    print(
        "WARNING: API_KEY is using the built-in placeholder value; "
        "set the API_KEY environment variable before exposing this service."
    )
api_key_header = APIKeyHeader(name=API_KEY_NAME)


# Dependency that validates the API key.
def get_api_key(api_key: str = Depends(api_key_header)):
    """Validate the key supplied in the ``X-API-Key`` header.

    Args:
        api_key: Header value resolved by ``api_key_header``. Note that this
            parameter shadows the module-level Google ``api_key``.

    Returns:
        The supplied key when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the supplied key does not match.
    """
    # Compare as bytes in constant time: this avoids leaking how much of the
    # key matched and also accepts non-ASCII header values.
    supplied = api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key


# Root endpoint, used to check that the service is running.
@app.get("/")
async def read_root():
    """Return a static greeting."""
    return {"message": "Hello, World!"}


# POST endpoint: accepts a question and returns the model's answer.
@app.post("/invoke")
async def rag_endpoint(request: InvokeRequest, api_key: str = Depends(get_api_key)):
    """Answer the question in the request body.

    Args:
        request: Body containing the question.
        api_key: Validated API key (unused beyond enforcing authentication).

    Returns:
        A dict with the answer under ``"result"``.
    """
    # invoke() makes blocking calls; run it in the threadpool so the event
    # loop stays free to serve other requests.
    result = await run_in_threadpool(invoke, request.question)
    return {"result": result}


# Start the application; equivalent to `uvicorn main:app --reload`.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)