woonchen commited on
Commit
b67a4cc
·
verified ·
1 Parent(s): 7d16b85

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +66 -30
main.py CHANGED
@@ -1,34 +1,38 @@
 
 
 
1
  from langchain.text_splitter import RecursiveCharacterTextSplitter
2
  from langchain_community.vectorstores import Chroma
3
  from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
4
- import PyPDF2
5
- import os
6
- import gradio as gr
7
- import google.generativeai as genai
8
  from langchain.chains import ConversationalRetrievalChain
9
- from langchain_huggingface import HuggingFaceEmbeddings
 
 
 
 
10
 
11
  print('程式初始化')
12
 
13
  # 設定 Google API 金鑰
14
- genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
 
 
 
15
 
16
  # 選擇模型
17
  llm_model = 'gemini-1.5-flash'
18
  embeddings_model = "models/embedding-001"
19
  pdf_dir = 'data'
20
 
21
- # 讀取 PDF 檔案
22
  print('-' * 21, '讀取資料', '-' * 21)
23
  docs = ""
24
  for filename in os.listdir(pdf_dir):
25
  if filename.endswith('.pdf'):
26
  print(filename)
27
- with open(os.path.join(pdf_dir, filename), 'rb') as pdf_file:
28
- pdf_reader = PyPDF2.PdfReader(pdf_file)
29
- for i in range(len(pdf_reader.pages)):
30
- page = pdf_reader.pages[i]
31
- docs += page.extract_text()
32
 
33
  print('-' * 21, '讀取完成', '-' * 21)
34
 
@@ -39,40 +43,72 @@ if docs:
39
 
40
  # 建立嵌入模型和檢索器
41
  embeddings = GoogleGenerativeAIEmbeddings(
42
- model=embeddings_model, google_api_key=os.getenv("GOOGLE_API_KEY")
43
  )
44
  retriever = Chroma.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 1})
45
  print('分割文本完成')
46
 
47
  # 初始化 Gemini 模型
48
  llm = ChatGoogleGenerativeAI(
49
- model=llm_model, temperature=0.1, google_api_key=os.getenv("GOOGLE_API_KEY")
50
  )
51
  print('模型載入完成')
52
 
 
 
53
 
54
- # 定義 invoke 函數
55
  chat_history = []
56
- def invoke(question):
 
 
57
  print('invoke 函數觸發')
58
- if docs:
59
- system_prompt = (
60
- "You are an assistant for question-answering tasks. "
61
- "Use the following pieces of retrieved context to answer the question. "
62
-
63
- )
64
- #"If you don't know the answer, say that you don't know."
65
  # 初始化 ConversationalRetrievalChain
66
  qa_chain = ConversationalRetrievalChain.from_llm(
67
  llm=llm, retriever=retriever
68
  )
69
 
70
- # 调用链并传递 chat_history
71
  response = qa_chain.invoke({"question": question, "chat_history": chat_history})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- # 更新 chat_history,保留上下文
74
- # chat_history.append((question, response['answer']))
75
- else:
76
- response = 'No context!'
77
-
78
- return response['answer']
 
1
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain.chains import ConversationalRetrievalChain
# BUG FIX: genai.configure() below needs this import; the previous revision
# dropped it while removing the gradio code, causing a NameError at startup.
import google.generativeai as genai
import pdfplumber
import os

# Initialize the FastAPI application
app = FastAPI()

print('程式初始化')

# Configure the Google API key; fail fast at startup if it is missing
# rather than erroring on the first model call.
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
    raise ValueError("GOOGLE_API_KEY is not set")
genai.configure(api_key=api_key)
21
 
22
# Model selection
llm_model = 'gemini-1.5-flash'          # chat model
embeddings_model = "models/embedding-001"  # embedding model
pdf_dir = 'data'                        # directory scanned for *.pdf files

# Read every PDF in `pdf_dir` and concatenate the extracted text into `docs`.
print('-' * 21, '讀取資料', '-' * 21)
docs = ""
for filename in os.listdir(pdf_dir):
    if filename.endswith('.pdf'):
        print(filename)
        with pdfplumber.open(os.path.join(pdf_dir, filename)) as pdf:
            for page in pdf.pages:
                # BUG FIX: extract_text() returns None for pages without a
                # text layer (e.g. scanned images); `or ""` prevents a
                # TypeError on `str += None`.
                docs += page.extract_text() or ""

print('-' * 21, '讀取完成', '-' * 21)
38
 
 
43
 
44
    # Build the embedding model and the retriever over the split texts
    # (the enclosing `if docs:` and the text-splitting step precede this view).
    embeddings = GoogleGenerativeAIEmbeddings(
        model=embeddings_model, google_api_key=api_key
    )
    # k=1: retrieve only the single most similar chunk per query
    retriever = Chroma.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 1})
    print('分割文本完成')

    # Initialize the Gemini chat model (low temperature for factual answers)
    llm = ChatGoogleGenerativeAI(
        model=llm_model, temperature=0.1, google_api_key=api_key
    )
    print('模型載入完成')

else:
    # No PDF text was extracted — abort startup instead of serving an empty index
    raise ValueError("No documents found in the 'data' directory.")
59
 
60
# Conversation history kept as a module-level global (in-memory only).
# NOTE(review): nothing ever appends to it, so every call is effectively
# history-free — confirm whether that is intentional.
chat_history = []

# Answer a single question against the retrieved PDF context.
def invoke(question: str):
    """Run the ConversationalRetrievalChain for `question` and return the answer string."""
    print('invoke 函數觸發')
    try:
        # The chain is rebuilt on every call; `llm` and `retriever` are the
        # module-level globals created at startup.
        # Initialize the ConversationalRetrievalChain
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm, retriever=retriever
        )

        # Call the QA chain and extract the answer field from its response dict
        response = qa_chain.invoke({"question": question, "chat_history": chat_history})
        return response['answer']

    except Exception as e:
        # Deliberate best-effort: any failure is logged and reported as a
        # plain string so the HTTP endpoint still returns 200.
        # NOTE(review): callers cannot distinguish this message from a real answer.
        print(f"Error during invoke: {e}")
        return "An error occurred while processing the request."
79
+
80
+
81
# Request payload for the /invoke endpoint
class InvokeRequest(BaseModel):
    question: str

# API-key security scheme: clients send the key in the X-API-Key header
API_KEY_NAME = "X-API-Key"
# NOTE(review): insecure placeholder fallback — deployments must set the
# API_KEY environment variable; consider failing fast when it is unset.
API_KEY = os.getenv("API_KEY", "your_secret_api_key")
api_key_header = APIKeyHeader(name=API_KEY_NAME)

# Dependency that validates the API key on protected endpoints.
def get_api_key(api_key: str = Depends(api_key_header)):
    """Return the presented key if valid; raise HTTP 403 otherwise."""
    import secrets  # stdlib; local import keeps the module import block untouched
    # compare_digest is constant-time, so a mismatch does not leak how many
    # leading characters of the key were correct via response timing.
    if not secrets.compare_digest(api_key, API_KEY):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key
98
+
99
# Root endpoint: simple liveness check
@app.get("/")
async def read_root():
    return {"message": "Hello, World!"}

# POST endpoint: accept a question, return the model's answer.
# Declared as plain `def` (not `async def`) on purpose: `invoke` performs
# blocking network calls to the LLM, and FastAPI runs sync path functions
# in its threadpool, so the event loop is not stalled per request.
# The HTTP interface (path, payload, response shape, auth) is unchanged.
@app.post("/invoke")
def rag_endpoint(request: InvokeRequest, api_key: str = Depends(get_api_key)):
    result = invoke(request.question)
    return {"result": result}

# Start the app with: `uvicorn main:app --reload`
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
114