File size: 3,938 Bytes
b67a4cc
 
 
e6441b6
 
 
 
b67a4cc
 
bd353c4
4fcb5d4
b67a4cc
 
 
f91846a
e6441b6
f91846a
e6441b6
b67a4cc
 
 
 
98dc21e
e6441b6
 
 
 
98dc21e
b67a4cc
e6441b6
 
 
 
 
b67a4cc
 
 
2cf7795
e6441b6
98dc21e
e6441b6
 
 
 
f23a640
e6441b6
 
b67a4cc
e6441b6
 
 
98dc21e
e6441b6
 
b67a4cc
e6441b6
 
55b1be9
b67a4cc
 
7d16b85
4fcb5d4
 
 
 
 
 
 
 
 
 
 
 
b67a4cc
56a4e1a
b67a4cc
 
 
e6441b6
b67a4cc
e6441b6
 
 
98dc21e
e6441b6
b67a4cc
4fcb5d4
e6441b6
4fcb5d4
 
 
b67a4cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6441b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import secrets

import google.generativeai as genai
import pdfplumber
from deep_translator import GoogleTranslator
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import APIKeyHeader
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from pydantic import BaseModel

# Create the FastAPI application instance.
app = FastAPI()

print('程式初始化')

# The Google API key is mandatory: abort at import time when it is missing
# or empty, then hand it to the google.generativeai client.
api_key = os.environ.get("GOOGLE_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("GOOGLE_API_KEY is not set")
genai.configure(api_key=api_key)

# Model names and the directory that is scanned for PDF files.
pdf_dir = 'data'
embeddings_model = "models/embedding-001"
llm_model = 'gemini-1.5-flash'

# Read every PDF under pdf_dir and gather the extracted text into `docs`.
print('-' * 21, '讀取資料', '-' * 21)
page_texts = []
# sorted() makes the ingestion order deterministic; os.listdir() returns
# directory entries in arbitrary order.
for filename in sorted(os.listdir(pdf_dir)):
    if filename.endswith('.pdf'):
        print(filename)
        with pdfplumber.open(os.path.join(pdf_dir, filename)) as pdf:
            for page in pdf.pages:
                # extract_text() can yield None or an empty string for pages
                # without a text layer (e.g. scanned images); skip those
                # instead of crashing on `str + None`.
                page_text = page.extract_text()
                if page_text:
                    page_texts.append(page_text)

# Join once instead of repeated `+=`, separating pages with a newline so the
# last word of one page does not fuse with the first word of the next.
# `docs` stays an empty string when no text was extracted at all.
docs = "\n".join(page_texts)

print('-' * 21, '讀取完成', '-' * 21)

# Nothing was extracted from the PDFs: there is nothing to index, so abort.
if not docs:
    raise ValueError("No documents found in the 'data' directory.")

# Split the combined text into overlapping chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
texts = text_splitter.split_text(docs)

# Embed the chunks into a Chroma store and expose it as a retriever that
# returns a single chunk per query (k=1).
embeddings = GoogleGenerativeAIEmbeddings(model=embeddings_model, google_api_key=api_key)
vector_store = Chroma.from_texts(texts, embeddings)
retriever = vector_store.as_retriever(search_kwargs={"k": 1})
print('分割文本完成')

# Instantiate the Gemini chat model.
llm = ChatGoogleGenerativeAI(model=llm_model, temperature=0.1, google_api_key=api_key)
print('模型載入完成')


# Translation helpers
def translate_to_english(text):
    """Return ``text`` translated to English.

    The source language is auto-detected by GoogleTranslator.
    """
    translator = GoogleTranslator(source='auto', target='en')
    return translator.translate(text)

def translate_to_chinese(text):
    """Return ``text`` translated to Traditional Chinese (``zh-TW``).

    The source language is auto-detected by GoogleTranslator.
    """
    translator = GoogleTranslator(source='auto', target='zh-TW')
    return translator.translate(text)





# Module-level chat history that invoke() passes to the QA chain.
# NOTE(review): nothing in the visible code appends to this list, so the
# chain always receives an empty history - confirm this is intended.
chat_history = []

# Answer a question with the retrieval chain.
def invoke(question: str):
    """Answer ``question`` using the retrieval QA chain.

    The question is translated to English, sent to a freshly built
    ConversationalRetrievalChain together with the module-level
    ``chat_history``, and the chain's ``answer`` is translated to
    Traditional Chinese before being returned.

    Any exception is printed and swallowed; in that case a fixed English
    error message is returned instead of an answer.
    """
    print('invoke 函數觸發')
    try:
        # A new chain is assembled on every call from the shared llm/retriever.
        chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever)

        english_question = translate_to_english(question)
        payload = {"question": english_question, "chat_history": chat_history}
        result = chain.invoke(payload)
        return translate_to_chinese(result['answer'])
    except Exception as e:
        print(f"Error during invoke: {e}")
        return "An error occurred while processing the request."


# Request body model for the POST /invoke endpoint.
class InvokeRequest(BaseModel):
    # The user's question; rag_endpoint() forwards it to invoke().
    question: str

# API key settings: clients authenticate with the X-API-Key request header.
API_KEY_NAME = "X-API-Key"
# Placeholder used when the API_KEY environment variable is not set. It is
# kept so existing deployments keep starting, but it is not a secret.
_DEFAULT_API_KEY = "your_secret_api_key"
API_KEY = os.getenv("API_KEY", _DEFAULT_API_KEY)
if API_KEY == _DEFAULT_API_KEY:
    # SECURITY: anyone who has seen this source can authenticate with the
    # placeholder, so make the misconfiguration visible at start-up.
    print(
        "WARNING: API_KEY is not set (or equals the built-in placeholder); "
        "set the API_KEY environment variable before exposing this service."
    )
api_key_header = APIKeyHeader(name=API_KEY_NAME)

# Dependency that validates the API key sent with a request.
def get_api_key(api_key: str = Depends(api_key_header)):
    """Validate the key from the ``X-API-Key`` header.

    Args:
        api_key: Header value injected through ``api_key_header``.

    Returns:
        The supplied key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 Forbidden when the key does not match.
    """
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of the key were correct. Both sides are encoded to
    # bytes because compare_digest() rejects non-ASCII str arguments.
    supplied = api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key

# Root endpoint: lets callers check that the service is up and responding.
@app.get("/")
async def read_root():
    # Fixed greeting payload; no authentication is required here.
    greeting = {"message": "Hello, World!"}
    return greeting

# POST endpoint: accepts a question and returns the model's answer.
# `request` carries the question; the `api_key` dependency only enforces
# authentication (get_api_key raises 403 on a bad key), its value is unused.
# The response body is {"result": <answer or error message from invoke()>}.
@app.post("/invoke")
async def rag_endpoint(request: InvokeRequest, api_key: str = Depends(get_api_key)):
    # invoke() is synchronous and makes blocking calls (translation requests
    # and the QA chain). Calling it directly inside an async endpoint would
    # stall the event loop for every other request, so run it in a worker
    # thread and await the result.
    result = await run_in_threadpool(invoke, request.question)
    return {"result": result}

# Start the server when this file is executed directly
# (alternatively run `uvicorn main:app --reload` from the command line).
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces on port 8000 with auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)