Update main.py
Browse files
main.py
CHANGED
@@ -1,34 +1,38 @@
|
|
|
|
|
|
|
|
1 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
2 |
from langchain_community.vectorstores import Chroma
|
3 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
|
4 |
-
import PyPDF2
|
5 |
-
import os
|
6 |
-
import gradio as gr
|
7 |
-
import google.generativeai as genai
|
8 |
from langchain.chains import ConversationalRetrievalChain
|
9 |
-
|
|
|
|
|
|
|
|
|
10 |
|
11 |
print('程式初始化')
|
12 |
|
13 |
# 設定 Google API 金鑰
|
14 |
-
|
|
|
|
|
|
|
15 |
|
16 |
# 選擇模型
|
17 |
llm_model = 'gemini-1.5-flash'
|
18 |
embeddings_model = "models/embedding-001"
|
19 |
pdf_dir = 'data'
|
20 |
|
21 |
-
# 讀取 PDF
|
22 |
print('-' * 21, '讀取資料', '-' * 21)
|
23 |
docs = ""
|
24 |
for filename in os.listdir(pdf_dir):
|
25 |
if filename.endswith('.pdf'):
|
26 |
print(filename)
|
27 |
-
with open(os.path.join(pdf_dir, filename)
|
28 |
-
|
29 |
-
|
30 |
-
page = pdf_reader.pages[i]
|
31 |
-
docs += page.extract_text()
|
32 |
|
33 |
print('-' * 21, '讀取完成', '-' * 21)
|
34 |
|
@@ -39,40 +43,72 @@ if docs:
|
|
39 |
|
40 |
# 建立嵌入模型和檢索器
|
41 |
embeddings = GoogleGenerativeAIEmbeddings(
|
42 |
-
model=embeddings_model, google_api_key=
|
43 |
)
|
44 |
retriever = Chroma.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 1})
|
45 |
print('分割文本完成')
|
46 |
|
47 |
# 初始化 Gemini 模型
|
48 |
llm = ChatGoogleGenerativeAI(
|
49 |
-
model=llm_model, temperature=0.1, google_api_key=
|
50 |
)
|
51 |
print('模型載入完成')
|
52 |
|
|
|
|
|
53 |
|
54 |
-
#
|
55 |
chat_history = []
|
56 |
-
|
|
|
|
|
57 |
print('invoke 函數觸發')
|
58 |
-
|
59 |
-
system_prompt = (
|
60 |
-
"You are an assistant for question-answering tasks. "
|
61 |
-
"Use the following pieces of retrieved context to answer the question. "
|
62 |
-
|
63 |
-
)
|
64 |
-
#"If you don't know the answer, say that you don't know."
|
65 |
# 初始化 ConversationalRetrievalChain
|
66 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
67 |
llm=llm, retriever=retriever
|
68 |
)
|
69 |
|
70 |
-
#
|
71 |
response = qa_chain.invoke({"question": question, "chat_history": chat_history})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
# 更新 chat_history,保留上下文
|
74 |
-
# chat_history.append((question, response['answer']))
|
75 |
-
else:
|
76 |
-
response = 'No context!'
|
77 |
-
|
78 |
-
return response['answer']
|
|
|
1 |
+
import os
import secrets

import google.generativeai as genai
import pdfplumber
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import APIKeyHeader
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from pydantic import BaseModel
|
10 |
+
|
11 |
+
# Initialize the FastAPI application.
app = FastAPI()

print('程式初始化')

# Configure the Google API key; fail fast at startup when it is missing
# so the service never runs half-configured.
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
    raise ValueError("GOOGLE_API_KEY is not set")
# Register the key with the google.generativeai SDK; the langchain
# wrappers later in this file also receive it explicitly via
# google_api_key=.
genai.configure(api_key=api_key)
|
21 |
|
22 |
# Model selection.
llm_model = 'gemini-1.5-flash'             # chat model used for answering
embeddings_model = "models/embedding-001"  # embedding model for the vector store
pdf_dir = 'data'                           # directory scanned for *.pdf files
|
26 |
|
27 |
+
# Read every PDF under pdf_dir and concatenate the extracted text into
# a single string `docs` used downstream to build the vector store.
print('-' * 21, '讀取資料', '-' * 21)
_doc_parts = []
for filename in os.listdir(pdf_dir):
    if filename.endswith('.pdf'):
        print(filename)
        with pdfplumber.open(os.path.join(pdf_dir, filename)) as pdf:
            for page in pdf.pages:
                # extract_text() can return None for pages with no text
                # layer (e.g. scanned images) — skip those instead of
                # crashing on `str += None`.
                page_text = page.extract_text()
                if page_text:
                    _doc_parts.append(page_text)
# Join once instead of repeated `docs +=` (avoids quadratic copies on
# large corpora); result is identical.
docs = "".join(_doc_parts)

print('-' * 21, '讀取完成', '-' * 21)
|
38 |
|
|
|
43 |
|
44 |
# 建立嵌入模型和檢索器
|
45 |
embeddings = GoogleGenerativeAIEmbeddings(
|
46 |
+
model=embeddings_model, google_api_key=api_key
|
47 |
)
|
48 |
retriever = Chroma.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 1})
|
49 |
print('分割文本完成')
|
50 |
|
51 |
# 初始化 Gemini 模型
|
52 |
llm = ChatGoogleGenerativeAI(
|
53 |
+
model=llm_model, temperature=0.1, google_api_key=api_key
|
54 |
)
|
55 |
print('模型載入完成')
|
56 |
|
57 |
+
else:
|
58 |
+
raise ValueError("No documents found in the 'data' directory.")
|
59 |
|
60 |
+
# Conversation history shared across requests (module-level state).
# NOTE(review): a single global history mixes context between concurrent
# users — confirm whether per-session history is wanted.
chat_history = []

# QA chain cache: built once on first use instead of on every request,
# since `llm` and `retriever` are module-level and never change.
_qa_chain = None


def invoke(question: str):
    """Answer *question* against the retriever-backed QA chain.

    Returns the chain's answer string, or a generic error message when
    anything goes wrong (the underlying error is printed, not raised).
    """
    global _qa_chain
    print('invoke 函數觸發')
    try:
        # Lazily initialize the ConversationalRetrievalChain; creating
        # it per call was pure overhead with identical arguments.
        if _qa_chain is None:
            _qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm, retriever=retriever
            )

        # Run the chain and return its answer.
        response = _qa_chain.invoke({"question": question, "chat_history": chat_history})
        return response['answer']

    except Exception as e:
        # Boundary handler: report a generic message to the caller
        # rather than leaking internals in the HTTP response.
        print(f"Error during invoke: {e}")
        return "An error occurred while processing the request."
|
79 |
+
|
80 |
+
|
81 |
+
# Request body schema for the /invoke endpoint.
class InvokeRequest(BaseModel):
    # The user's question, forwarded verbatim to the QA chain.
    question: str
|
84 |
+
|
85 |
+
# API-key authentication for the HTTP endpoints.
API_KEY_NAME = "X-API-Key"
# NOTE(review): the hard-coded fallback means the service still runs
# with a publicly known key when API_KEY is unset — consider failing
# fast here the way GOOGLE_API_KEY does above.
API_KEY = os.getenv("API_KEY", "your_secret_api_key")
api_key_header = APIKeyHeader(name=API_KEY_NAME)


def get_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency validating the X-API-Key request header.

    Returns the presented key on success; raises HTTP 403 otherwise.
    """
    # Constant-time comparison so attackers cannot recover the key
    # prefix-by-prefix via response-timing differences.
    if not secrets.compare_digest(api_key, API_KEY):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key
|
98 |
+
|
99 |
+
# Root endpoint, used to check that the service is running.
@app.get("/")
async def read_root():
    """Health check: always responds with a static greeting."""
    return {"message": "Hello, World!"}
|
103 |
+
|
104 |
+
# POST endpoint: lets a client send a question and get the model's answer.
@app.post("/invoke")
async def rag_endpoint(request: InvokeRequest, api_key: str = Depends(get_api_key)):
    """Authenticated RAG endpoint; delegates the question to invoke()."""
    result = invoke(request.question)
    return {"result": result}
|
109 |
+
|
110 |
+
# Start the application: run `uvicorn main:app --reload`
if __name__ == "__main__":
    import uvicorn
    # NOTE(review): 0.0.0.0 binds on all network interfaces — confirm
    # this exposure is intended outside local development.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|