chitkenkhoi commited on
Commit
7936e19
·
1 Parent(s): e2180d0
Files changed (3) hide show
  1. Dockerfile +10 -0
  2. app.py +234 -0
  3. requirements.txt +16 -0
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY ./app.py /code/app.py
9
+
10
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from sentence_transformers import util, SentenceTransformer
6
+ import redis
7
+ import json
8
+ from typing import Dict, List
9
+ import google.generativeai as genai
10
+ from flask import Flask, request, jsonify, Response
11
+ import requests
12
+ from io import StringIO
13
+
14
+ # Initialize Flask app
15
+ app = Flask(__name__)
16
+
17
+ # Redis configuration
18
+ r = redis.Redis(
19
+ host='redis-12878.c1.ap-southeast-1-1.ec2.redns.redis-cloud.com',
20
+ port=12878,
21
+ db=0,
22
+ password="qKl6znBvULaveJhkjIjMr7RCwluJjjbH",
23
+ decode_responses=True
24
+ )
25
+
26
+ # Device configuration
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+
29
+ # Load CSV from Google Drive
30
+ def load_csv_from_drive():
31
+ file_id = "1vU23pGS-kkpkUFNDl8BmuBUc2Am0966p"
32
+ url = f"https://drive.google.com/uc?id={file_id}"
33
+ response = requests.get(url)
34
+ csv_content = StringIO(response.text)
35
+ df = pd.read_csv(csv_content)[['text', 'embeddings']]
36
+
37
+ # Process embeddings
38
+ df["embeddings"] = df["embeddings"].apply(
39
+ lambda x: np.fromstring(x.strip("[]"), sep=",", dtype=np.float32)
40
+ )
41
+ return df
42
+
43
+ # Load data and initialize models
44
+ text_chunks_and_embedding_df = load_csv_from_drive()
45
+ pages_and_chunks = text_chunks_and_embedding_df.to_dict(orient="records")
46
+ embeddings = torch.tensor(
47
+ np.vstack(text_chunks_and_embedding_df["embeddings"].values),
48
+ dtype=torch.float32
49
+ ).to(device)
50
+
51
+ # Initialize embedding model
52
+ embedding_model = SentenceTransformer(
53
+ model_name_or_path="keepitreal/vietnamese-sbert",
54
+ device=device
55
+ )
56
+
57
+ def store_conversation(conversation_id: str, q: str, a: str) -> None:
58
+ conversation_element = {
59
+ 'q': q,
60
+ 'a': a,
61
+ }
62
+ conversation_json = json.dumps(conversation_element)
63
+ r.lpush(f'conversation_{conversation_id}', conversation_json)
64
+ current_length = r.llen(f'conversation_{conversation_id}')
65
+ if current_length > 2:
66
+ r.rpop(f'conversation_{conversation_id}')
67
+
68
+ def retrieve_conversation(conversation_id):
69
+ conversation = r.lrange(f'conversation_{conversation_id}', 0, -1)
70
+ return [json.loads(c) for c in conversation]
71
+
72
+ def combine_vectors_method2(vector_weight_pairs):
73
+ weight_norm = np.sqrt(sum(weight**2 for _, weight in vector_weight_pairs))
74
+ combined_vector = np.zeros_like(vector_weight_pairs[0][0])
75
+
76
+ for vector, weight in vector_weight_pairs:
77
+ normalized_weight = weight / weight_norm
78
+ combined_vector += vector * normalized_weight
79
+
80
+ return combined_vector
81
+
82
+ def get_weighted_query(current_question: str, parsed_conversation: List[Dict]) -> np.ndarray:
83
+ current_vector = embedding_model.encode(current_question)
84
+ weighted_parts = [(current_vector, 1.0)]
85
+
86
+ if parsed_conversation:
87
+ context_string = " ".join(
88
+ f"{chat['q']} {chat['a']}" for chat in parsed_conversation
89
+ )
90
+ context_vector = embedding_model.encode(context_string)
91
+ similarity = util.pytorch_cos_sim(current_vector, context_vector)[0][0].item()
92
+ weight = 1.0 if similarity > 0.4 else 0.5
93
+ weighted_parts.append((context_vector, weight))
94
+
95
+ weighted_query_vector = combine_vectors_method2(weighted_parts)
96
+ weighted_query_vector = torch.from_numpy(weighted_query_vector).to(torch.float32)
97
+
98
+ # Normalize vector
99
+ norm = torch.norm(weighted_query_vector)
100
+ weighted_query_vector = weighted_query_vector / norm if norm > 0 else weighted_query_vector
101
+
102
+ return weighted_query_vector.numpy()
103
+
104
+ def retrieve_relevant_resources(query_vector, embeddings, similarity_threshold=0.5, n_resources_to_return=10):
105
+ query_embedding = torch.from_numpy(query_vector).to(torch.float32)
106
+ if len(query_embedding.shape) == 1:
107
+ query_embedding = query_embedding.unsqueeze(0)
108
+ query_embedding = query_embedding.cuda()
109
+
110
+ if embeddings.shape[1] != query_embedding.shape[1]:
111
+ query_embedding = torch.nn.functional.pad(
112
+ query_embedding,
113
+ (0, embeddings.shape[1] - query_embedding.shape[1])
114
+ )
115
+
116
+ query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1)
117
+ embeddings_normalized = torch.nn.functional.normalize(embeddings, p=2, dim=1)
118
+
119
+ cosine_scores = torch.matmul(query_embedding, embeddings_normalized.t())[0]
120
+
121
+ mask = cosine_scores >= similarity_threshold
122
+ filtered_scores = cosine_scores[mask]
123
+ filtered_indices = mask.nonzero().squeeze()
124
+
125
+ if len(filtered_scores) == 0:
126
+ return torch.tensor([]), torch.tensor([])
127
+
128
+ k = min(n_resources_to_return, len(filtered_scores))
129
+ scores, indices = torch.topk(filtered_scores, k=k)
130
+ final_indices = filtered_indices[indices]
131
+
132
+ return scores, final_indices
133
+
134
+ def prompt_formatter(query: str, context_items: List[Dict], history: List[Dict] = None, isFirst = False) -> str:
135
+ context = "- " + "\n- ".join([item["text"] for item in context_items])
136
+
137
+ history_str = ""
138
+ if history:
139
+ history_str = "\nLịch sử hội thoại:\n"
140
+ for qa in history:
141
+ history_str += f"Câu hỏi: {qa['q']}\n"
142
+ history_str += f"Trả lời: {qa['a']}\n"
143
+
144
+ if isFirst:
145
+ example = """
146
+ Đồng thời hãy thêm vào một dòng vào cuối câu trả lời của bạn, dòng đó sẽ là dòng nói về chủ đề mà người dùng đang hỏi.
147
+ Chủ đề nên càng ngắn gọn càng tốt (tối đa 7 từ).
148
+ Ví dụ:
149
+ Câu hỏi: "Trường đại học bách khoa thành lập vào năm nào?"
150
+ Ngữ cảnh có đề cập về trường đại học bách khoa thành lập vào năm 1957.
151
+ Trả lời: "Trường đại học bách khoa thành lập vào năm 1957. \n Chủ đề-123: Trường đại học Bách khoa"
152
+ """
153
+ else:
154
+ example = """
155
+ Ví dụ:
156
+ Câu hỏi: "Trường đại học bách khoa thành lập vào năm nào?"
157
+ Ngữ cảnh có đề cập về trường đại học bách khoa thành lập vào năm 1957.
158
+ Trả lời: "Trường đại học bách khoa thành lập vào năm 1957."
159
+ """
160
+
161
+ base_prompt = """Dựa trên các thông tin ngữ cảnh sau đây, hãy trả lời câu hỏi.
162
+ Hãy trích xuất các đoạn văn bản liên quan từ ngữ cảnh trước khi trả lời.
163
+ Chỉ trả lời câu hỏi, không cần giải thích quá trình suy luận.
164
+ Đảm bảo câu trả lời càng chi tiết và giải thích càng tốt.
165
+ Hãy trả lời đầy đủ, không được cắt ngắn câu trả lời.
166
+ Nếu câu trả lời quá dài, hãy chia thành các phần nhỏ và trả lời từng phần.
167
+ Nếu không có ngữ cảnh hoặc ngữ cảnh không cung cấp thông tin cần thiết hãy trả lời là "Mình không có dữ liệu về câu hỏi này" và không thêm bất cứ thứ gì.
168
+ Không được nhắc về từ "ngữ cảnh" trong câu trả lời. Tôi muốn câu trả lời của mình có đầy đủ chủ ngữ vị ngữ.
169
+ {example}
170
+
171
+ Ngữ cảnh:
172
+ {context}
173
+
174
+ Lịch sử cuộc hội thoại hiện tại:
175
+ {history}
176
+
177
+ Câu hỏi: {query}
178
+ Trả lời:"""
179
+
180
+ return base_prompt.format(context=context, history=history_str, query=query, example=example)
181
+
182
+ def ask_with_history_v3(query: str, conversation_id: str, isFirst):
183
+ parsed_conversation = retrieve_conversation(conversation_id)
184
+ weighted_query_vector = get_weighted_query(query, parsed_conversation)
185
+
186
+ threshold = 0.4
187
+ scores, indices = retrieve_relevant_resources(
188
+ query_vector=weighted_query_vector,
189
+ similarity_threshold=threshold,
190
+ embeddings=embeddings
191
+ )
192
+
193
+ scores_cpu = [score.cpu() for score in scores]
194
+ filtered_pairs = [(score, idx) for score, idx in zip(scores_cpu, indices) if score.item() >= threshold]
195
+
196
+ if filtered_pairs:
197
+ filtered_scores, filtered_indices = zip(*filtered_pairs)
198
+ context_items = [pages_and_chunks[i] for i in filtered_indices]
199
+ for i, item in enumerate(context_items):
200
+ item["score"] = filtered_scores[i]
201
+ else:
202
+ context_items = []
203
+
204
+ prompt = prompt_formatter(query=query, context_items=context_items, history=parsed_conversation, isFirst=isFirst)
205
+
206
+ genai.configure(api_key="AIzaSyDluIEKEhT1Dw2zx7SHEdmKipwBcYOmFQw")
207
+ model = genai.GenerativeModel("gemini-1.5-flash")
208
+ response = model.generate_content(prompt, stream=True)
209
+
210
+ for chunk in response:
211
+ yield chunk.text
212
+
213
+ store_conversation(conversation_id, query, response.text)
214
+
215
+ # API endpoints
216
+ @app.route('/ping', methods=['GET'])
217
+ def ping():
218
+ return jsonify("Service is running")
219
+
220
+ @app.route('/generate', methods=['POST'])
221
+ def generate_response():
222
+ query = request.json['query']
223
+ conversation_id = request.json['conversation_id']
224
+ isFirst = request.json['is_first']
225
+
226
+ def generate():
227
+ for token in ask_with_history_v3(query, conversation_id, isFirst):
228
+ yield token
229
+
230
+ return Response(generate(), mimetype='text/plain')
231
+
232
+ if __name__ == '__main__':
233
+ # Initialize data and models
234
+ app.run(host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.0
2
+ torchvision==0.16.0
3
+ torchaudio==2.1.0
4
+ tqdm==4.66.1
5
+ sentence-transformers==2.2.2
6
+ accelerate==0.26.1
7
+ bitsandbytes==0.41.3
8
+ redis==5.0.1
9
+ google-generativeai==0.3.1
10
+ flask==3.0.0
11
+ pandas==2.1.3
12
+ numpy==1.26.2
13
+ transformers==4.36.2
14
+ huggingface-hub==0.19.4
15
+ spacy==3.7.2
16
+ regex==2023.10.3