|
import functools
import os
import secrets
import traceback
import warnings

import google.generativeai as genai
import pdfplumber
from deep_translator import GoogleTranslator
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import APIKeyHeader
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from pydantic import BaseModel
|
|
|
|
|
# FastAPI application instance; the route decorators further down register on it.
app = FastAPI()

# Start-up log line (the message means "program initializing").
print('程式初始化')
|
|
|
|
|
# Fail fast at import time when the Google credentials are missing; everything
# below (embeddings, LLM) needs them. os.getenv yields None when the variable is
# unset, and an empty string is treated as "not set" as well.
api_key = os.getenv("GOOGLE_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("GOOGLE_API_KEY is not set")

# Configure the google.generativeai client globally with that key.
genai.configure(api_key=api_key)
|
|
|
|
|
# Static configuration for the start-up pipeline.
llm_model, embeddings_model, pdf_dir = (
    'gemini-1.5-flash',        # chat model name given to ChatGoogleGenerativeAI
    "models/embedding-001",    # embedding model name given to GoogleGenerativeAIEmbeddings
    'data',                    # directory scanned for PDF files
)
|
|
|
|
|
print('-' * 21, '讀取資料', '-' * 21)

# Collect page texts in a list and join them once at the end. Repeated `+=` on
# a str can be quadratic. Joining with a newline also keeps the last word of one
# page from fusing with the first word of the next.
_page_texts = []

# sorted(): os.listdir() returns entries in arbitrary order, which would make
# the concatenated text (and hence the chunk boundaries) nondeterministic.
for filename in sorted(os.listdir(pdf_dir)):
    # Case-insensitive match so e.g. "report.PDF" is not silently skipped.
    if not filename.lower().endswith('.pdf'):
        continue
    print(filename)
    with pdfplumber.open(os.path.join(pdf_dir, filename)) as pdf:
        for page in pdf.pages:
            # Some pdfplumber versions return None for pages without a text
            # layer (e.g. scanned images); treat those as empty.
            page_text = page.extract_text() or ""
            # Skip blank pages so a folder of image-only PDFs leaves `docs`
            # empty and triggers the "No documents found" error downstream.
            if page_text.strip():
                _page_texts.append(page_text)

docs = "\n".join(_page_texts)
del _page_texts

print('-' * 21, '讀取完成', '-' * 21)
|
|
|
|
|
# Guard clause: refuse to start without any extracted PDF text.
if not docs:
    raise ValueError("No documents found in the 'data' directory.")

# Split the combined text into overlapping chunks for embedding.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
texts = text_splitter.split_text(docs)

# Embed the chunks into an in-process Chroma store and expose it as a
# retriever that returns the single best-matching chunk per query.
embeddings = GoogleGenerativeAIEmbeddings(
    model=embeddings_model,
    google_api_key=api_key,
)
retriever = (
    Chroma.from_texts(texts, embeddings)
    .as_retriever(search_kwargs={"k": 1})
)
print('分割文本完成')

# Chat model used by the QA chain; low temperature for mostly deterministic answers.
llm = ChatGoogleGenerativeAI(
    model=llm_model,
    temperature=0.1,
    google_api_key=api_key,
)
print('模型載入完成')
|
|
|
|
|
|
|
def translate_to_english(text):
    """Translate ``text`` into English.

    The source language is auto-detected (``source='auto'``); the translated
    string from deep_translator's GoogleTranslator is returned.
    """
    english_translator = GoogleTranslator(source='auto', target='en')
    translated = english_translator.translate(text)
    return translated
|
|
|
def translate_to_chinese(text):
    """Translate ``text`` into Traditional Chinese (``zh-TW``).

    The source language is auto-detected (``source='auto'``); the translated
    string from deep_translator's GoogleTranslator is returned.
    """
    chinese_translator = GoogleTranslator(source='auto', target='zh-TW')
    translated = chinese_translator.translate(text)
    return translated
|
|
|
|
|
|
|
|
|
|
|
|
|
# Conversation history handed to the QA chain on every call. Nothing visible in
# this module appends to it, so it stays empty and each question is answered
# without prior context. It is a single module-level list shared by all clients,
# so it should not be used to store per-user history as-is.
chat_history = []
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_qa_chain():
    """Build the conversational retrieval chain once and reuse it afterwards.

    The chain only wires together the module-level ``llm`` and ``retriever``,
    which are fixed after start-up, so rebuilding it per request is wasted work.
    """
    return ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever)


def invoke(question: str):
    """Answer ``question`` with retrieval-augmented generation over the loaded PDFs.

    The question is translated to English, and the chain answers it using the
    module-level ``chat_history``. The answer is then translated to Traditional
    Chinese (zh-TW).

    Returns:
        The translated answer string. If any step fails, returns a fixed error
        message string instead of raising. The error and its traceback are
        printed to the console.
    """
    print('invoke 函數觸發')
    try:
        qa_chain = _get_qa_chain()

        english_question = translate_to_english(question)
        response = qa_chain.invoke(
            {"question": english_question, "chat_history": chat_history}
        )
        return translate_to_chinese(response['answer'])
    except Exception as e:
        # Top-level boundary for the request: keep the original message, but
        # also emit the traceback so failures are diagnosable from the logs.
        print(f"Error during invoke: {e}")
        traceback.print_exc()
        return "An error occurred while processing the request."
|
|
|
|
|
|
|
# Request body schema for POST /invoke.
class InvokeRequest(BaseModel):
    # Free-form user question, forwarded unchanged to invoke().
    question: str
|
|
|
|
|
# Name of the request header that must carry the client's API key.
API_KEY_NAME = "X-API-Key"

# Shared secret clients must present. When the API_KEY environment variable is
# unset this falls back to a publicly visible placeholder. The fallback is kept
# for backward compatibility, but it offers no real protection, so warn loudly.
_DEFAULT_API_KEY = "your_secret_api_key"
API_KEY = os.getenv("API_KEY", _DEFAULT_API_KEY)
if API_KEY == _DEFAULT_API_KEY:
    warnings.warn(
        "API_KEY environment variable is not set; the API is protected only by "
        "the built-in placeholder key. Set API_KEY to a strong secret.",
        RuntimeWarning,
        stacklevel=1,
    )

# FastAPI security dependency that reads the key from the API_KEY_NAME header.
api_key_header = APIKeyHeader(name=API_KEY_NAME)
|
|
|
|
|
def get_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that authorizes a request by its API key header.

    Args:
        api_key: Value of the ``X-API-Key`` header, injected via ``api_key_header``.

    Returns:
        The validated key.

    Raises:
        HTTPException: 403 Forbidden when the key does not match ``API_KEY``.
    """
    # compare_digest takes time independent of where the inputs first differ,
    # so response timing cannot be used to recover the key piece by piece.
    # Compare bytes: compare_digest raises TypeError on non-ASCII str inputs.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key
|
|
|
|
|
# Root route: GET / responds with a fixed JSON greeting.
# (Documented with comments rather than a docstring so the generated OpenAPI
# description stays unchanged.)
@app.get("/")
async def read_root():
    payload = {"message": "Hello, World!"}
    return payload
|
|
|
|
|
# POST /invoke: answer a question via the RAG pipeline. Requires a valid
# X-API-Key header (enforced by the get_api_key dependency) and returns
# {"result": <answer text>}.
@app.post("/invoke")
async def rag_endpoint(request: InvokeRequest, api_key: str = Depends(get_api_key)):
    # invoke() is fully synchronous: translation, embedding lookup and LLM calls
    # all block on network I/O. Calling it directly inside an `async def` would
    # stall the event loop for every client, so run it on FastAPI's threadpool.
    result = await run_in_threadpool(invoke, request.question)
    return {"result": result}
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run as a script.
    import uvicorn

    # The app is passed as an import string ("main:app"), which assumes this
    # module is importable as `main`. reload=True restarts on code changes and
    # is meant for development.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
|
|
|
|