eli02 committed
Commit 8c23c78 · 1 Parent(s): 2d6fe2b

Remove unused parquet file and update requirements with specific package versions for better dependency management

[openai_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet → [all_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9f94d381f4dfcff0bbf6bfa5c84def47794d1596e12e2204a2a4bb413fc25a05
- size 2257769
+ oid sha256:ced650f23166f55939fb6dfec6df2fd7d83995a9db362a1a7460d36e6f3ab510
+ size 3118786
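
The parquet file is tracked with Git LFS, so the diff above only shows the pointer changing: a new object hash and a larger byte size for the re-embedded data. A quick sanity check after fetching the LFS object (e.g. with git lfs pull) could look like the sketch below; this is hypothetical and not part of the commit, with the column names inferred from how main.py below reads the file:

import pandas as pd

# Inspect the re-embedded parquet: main.py expects an "ext" column holding the
# chunk text plus one embedding column per model ("openai_embedding" and
# "all_mpnet_embedding"). The exact column order is an assumption.
df = pd.read_parquet(
    "[all_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet"
)
print(df.columns.tolist())
print(len(df), "rows")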
main.py CHANGED
@@ -7,11 +7,17 @@ from jose import JWTError, jwt
  from datetime import datetime, timedelta
  from openai import OpenAI
  from pathlib import Path
- from typing import List
+ from typing import List, Optional, Dict
+ from datasets import Dataset, load_dataset
+ from sentence_transformers import SentenceTransformer
+ from huggingface_hub import login
  import pandas as pd
+ import numpy as np
+ import torch as t
  import os
  import logging
-
+ from functools import lru_cache
+ from diskcache import Cache

  # Configure logging
  logging.basicConfig(level=logging.INFO)
@@ -19,6 +25,9 @@ logging.basicConfig(level=logging.INFO)
  # Initialize FastAPI app
  app = FastAPI()

+ # Initialize disk cache
+ cache = Cache('./cache')
+
  # JWT Configuration
  SECRET_KEY = os.environ.get("prime_auth", "c0369f977b69e717dc16f6fc574039eb2b1ebde38014d2be")
  REFRESH_SECRET_KEY = os.environ.get("prolonged_auth", "916018771b29084378c9362c0cd9e631fd4927b8aea07f91")
@@ -26,27 +35,141 @@ ALGORITHM = "HS256"
  ACCESS_TOKEN_EXPIRE_MINUTES = 30
  REFRESH_TOKEN_EXPIRE_DAYS = 7

- # OAuth2 scheme for token authentication
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")

- # Load credentials from environment variables
+ # Pydantic models
+ class QueryInput(BaseModel):
+     query: str
+
+ class SearchResult(BaseModel):
+     text: str
+     similarity: float
+     model_type: str
+
+ class TokenResponse(BaseModel):
+     access_token: str
+     refresh_token: str
+     token_type: str
+
+ class SaveInput(BaseModel):
+     user_type: str
+     username: str
+     query: str
+     retrieved_text: str
+     model_type: str
+     reaction: str
+
+ class SaveBatchInput(BaseModel):
+     items: List[SaveInput]
+
+ class RefreshRequest(BaseModel):
+     refresh_token: str
+
+ # Cache management
+ @lru_cache(maxsize=1)
+ def get_sentence_transformer():
+     """Load and cache the SentenceTransformer model with lru_cache"""
+     return SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device="cpu")
+
+ def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]:
+     """Try to get embeddings from cache"""
+     cache_key = f"{model_type}_{hash(text)}"
+     return cache.get(cache_key)
+
+ def set_cached_embeddings(text: str, model_type: str, embeddings: List[float]):
+     """Store embeddings in cache"""
+     cache_key = f"{model_type}_{hash(text)}"
+     cache.set(cache_key, embeddings, expire=86400)  # Cache for 24 hours
+
+ @lru_cache(maxsize=1)
+ def load_dataframe():
+     """Load and cache the parquet dataframe"""
+     database_file = Path(__file__).parent / "[all_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet"
+     return pd.read_parquet(database_file)
+
+ # Utility functions
+ def cosine_similarity(embedding_0, embedding_1):
+     dot_product = sum(a * b for a, b in zip(embedding_0, embedding_1))
+     norm_0 = sum(a * a for a in embedding_0) ** 0.5
+     norm_1 = sum(b * b for b in embedding_1) ** 0.5
+     return dot_product / (norm_0 * norm_1)
+
+ def generate_embedding(model, text: str, model_type: str) -> List[float]:
+     # Try to get from cache first
+     cached_embedding = get_cached_embeddings(text, model_type)
+     if cached_embedding is not None:
+         return cached_embedding
+
+     # Generate new embedding if not in cache
+     if model_type == "all-mpnet-base-v2":
+         chunk_embedding = model.encode(
+             text,
+             convert_to_tensor=True
+         )
+         embedding = np.array(t.Tensor.cpu(chunk_embedding)).tolist()
+     elif model_type == "openai":
+         response = model.embeddings.create(
+             input=text,
+             model="text-embedding-3-small"
+         )
+         embedding = response.data[0].embedding
+
+     # Cache the new embedding
+     set_cached_embeddings(text, model_type, embedding)
+     return embedding
+
+ def search_query(client, st_model, query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]:
+     # Generate embeddings for both models
+     mpnet_embedding = generate_embedding(st_model, query, "all-mpnet-base-v2")
+     openai_embedding = generate_embedding(client, query, "openai")
+
+     # Calculate similarities
+     df['mpnet_similarities'] = df.all_mpnet_embedding.apply(
+         lambda x: cosine_similarity(x, mpnet_embedding)
+     )
+     df['openai_similarities'] = df.openai_embedding.apply(
+         lambda x: cosine_similarity(x, openai_embedding)
+     )
+
+     # Get top results for each model
+     mpnet_results = df.nlargest(n, 'mpnet_similarities')
+     openai_results = df.nlargest(n, 'openai_similarities')
+
+     # Format results
+     results = []
+
+     for _, row in mpnet_results.iterrows():
+         results.append({
+             "text": row["ext"],
+             "similarity": float(row["mpnet_similarities"]),
+             "model_type": "all-mpnet-base-v2"
+         })
+
+     for _, row in openai_results.iterrows():
+         results.append({
+             "text": row["ext"],
+             "similarity": float(row["openai_similarities"]),
+             "model_type": "openai"
+         })
+
+     return results
+
+ # Authentication functions
  def load_credentials():
      credentials = {}
-     for i in range(1, 51): # Assuming you have 50 credentials
+     for i in range(1, 51):
          username = os.environ.get(f"login_{i}")
          password = os.environ.get(f"password_{i}")
          if username and password:
              credentials[username] = password
      return credentials

- # Authenticate user and create token
  def authenticate_user(username: str, password: str):
      credentials_dict = load_credentials()
      if username in credentials_dict and credentials_dict[username] == password:
          return username
      return None

- # Create JWT token
  def create_token(data: dict, expires_delta: timedelta, secret_key: str):
      to_encode = data.copy()
      expire = datetime.utcnow() + expires_delta
@@ -54,7 +177,6 @@ def create_token(data: dict, expires_delta: timedelta, secret_key: str):
      encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
      return encoded_jwt

- # Verify JWT token
  def verify_token(token: str, secret_key: str):
      credentials_exception = HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
@@ -70,71 +192,20 @@ def verify_token(token: str, secret_key: str):
          raise credentials_exception
      return username

- # Verify access token
  def verify_access_token(token: str = Depends(oauth2_scheme)):
      return verify_token(token, SECRET_KEY)

- # Verify refresh token
- def verify_refresh_token(token: str):
-     return verify_token(token, REFRESH_SECRET_KEY)
-
- # Load data from parquet file
- def load_data(database_file):
-     df = pd.read_parquet(database_file)
-
-     return df
-
- # Generate OpenAI embeddings
- def generate_openai_embeddings(client, text):
-     response = client.embeddings.create(
-         input=text,
-         model="text-embedding-3-small"
-     )
-     return response.data[0].embedding
-
- # Compute cosine similarity
- def cosine_similarity(embedding_0, embedding_1):
-     dot_product = sum(a * b for a, b in zip(embedding_0, embedding_1))
-     norm_0 = sum(a * a for a in embedding_0) ** 0.5
-     norm_1 = sum(b * b for b in embedding_1) ** 0.5
-     return dot_product / (norm_0 * norm_1)
-
- # Search query
- def search_query(client, query, df, n=3):
-     embedding = generate_openai_embeddings(client, query)
-     df['similarities'] = df.openai_embedding.apply(lambda x: cosine_similarity(x, embedding))
-     res = df.sort_values('similarities', ascending=False).head(n)
-     return res
-
- # Pydantic model for the query input
- class QueryInput(BaseModel):
-     query: str
-
- # Pydantic model for the search result
- class SearchResult(BaseModel):
-     text: str
-     similarity: float
-
- # Pydantic model for the token response
- class TokenResponse(BaseModel):
-     access_token: str
-     refresh_token: str
-     token_type: str
-
-
- # Root endpoint
+ # Endpoints
  @app.get("/")
  def index() -> FileResponse:
+     """Serve the custom HTML page from the static directory"""
      file_path = Path(__file__).parent / "static" / "index.html"
      return FileResponse(path=str(file_path), media_type="text/html")

- # Login endpoint to issue tokens
  @app.post("/login", response_model=TokenResponse)
  def login(form_data: OAuth2PasswordRequestForm = Depends()):
-     logging.info("Login attempt for user: %s", form_data.username)
      username = authenticate_user(form_data.username, form_data.password)
      if not username:
-         logging.warning("Authentication failed for user: %s", form_data.username)
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Invalid username or password",
@@ -142,47 +213,150 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
          )
      access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
      refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
-     access_token = create_token(data={"sub": username}, expires_delta=access_token_expires, secret_key=SECRET_KEY)
-     refresh_token = create_token(data={"sub": username}, expires_delta=refresh_token_expires, secret_key=REFRESH_SECRET_KEY)
-     logging.info("Tokens issued for user: %s", username)
-     return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
+     access_token = create_token(
+         data={"sub": username},
+         expires_delta=access_token_expires,
+         secret_key=SECRET_KEY
+     )
+     refresh_token = create_token(
+         data={"sub": username},
+         expires_delta=refresh_token_expires,
+         secret_key=REFRESH_SECRET_KEY
+     )
+     return {
+         "access_token": access_token,
+         "refresh_token": refresh_token,
+         "token_type": "bearer"
+     }

- # Refresh token endpoint
  @app.post("/refresh", response_model=TokenResponse)
- def refresh(refresh_token: str):
-     username = verify_refresh_token(refresh_token)
-     access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
-     access_token = create_token(data={"sub": username}, expires_delta=access_token_expires, secret_key=SECRET_KEY)
-     return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
+ async def refresh(refresh_request: RefreshRequest):
+     """
+     Endpoint to refresh an access token using a valid refresh token.
+     Returns a new access token and the existing refresh token.
+     """
+     try:
+         # Verify the refresh token
+         username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
+
+         # Create new access token
+         access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+         access_token = create_token(
+             data={"sub": username},
+             expires_delta=access_token_expires,
+             secret_key=SECRET_KEY
+         )
+
+         return {
+             "access_token": access_token,
+             "refresh_token": refresh_request.refresh_token,  # Return the same refresh token
+             "token_type": "bearer"
+         }
+
+     except JWTError:
+         raise HTTPException(
+             status_code=status.HTTP_401_UNAUTHORIZED,
+             detail="Could not validate credentials",
+             headers={"WWW-Authenticate": "Bearer"},
+         )

- # Search endpoint
  @app.post("/search", response_model=List[SearchResult])
- def search(
+ async def search(
      query_input: QueryInput,
      username: str = Depends(verify_access_token),
  ):
-     # Initialize OpenAI client
-     client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
-
-     # Load database
-     database_file = Path(__file__).parent / "[openai_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet"
-     df = load_data(database_file)
-     logging.info("Database loaded successfully")
-
-     # Perform search
-     res = search_query(client, query_input.query, df, n=3)
-
-     # Format results
-     results = [
-         SearchResult(text=row["ext"], similarity=row["similarities"])
-         for _, row in res.iterrows()
-     ]
+     try:
+         # Initialize clients using cached functions
+         client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+         st_model = get_sentence_transformer()
+         df = load_dataframe()
+
+         # Perform search with both models
+         results = search_query(client, st_model, query_input.query, df, n=1)
+         return [SearchResult(**result) for result in results]
+
+     except Exception as e:
+         logging.error(f"Search error: {str(e)}")
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Search failed: {str(e)}"
+         )

-     return results
+ @app.post("/save")
+ async def save_data(
+     save_input: SaveBatchInput,
+     username: str = Depends(verify_access_token)
+ ):
+     try:
+         # Login to Hugging Face
+         hf_token = os.environ.get("al_ghazali_rag_retrieval_evaluation")
+         if not hf_token:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail="Hugging Face API token not found"
+             )
+         login(token=hf_token)
+
+         # Prepare data for saving
+         data = {
+             "user_type": [],
+             "username": [],
+             "query": [],
+             "retrieved_text": [],
+             "model_type": [],
+             "reaction": []
+         }
+
+         # Add each item to the data dict
+         for item in save_input.items:
+             data["user_type"].append(item.user_type)
+             data["username"].append(item.username)
+             data["query"].append(item.query)
+             data["retrieved_text"].append(item.retrieved_text)
+             data["model_type"].append(item.model_type)
+             data["reaction"].append(item.reaction)
+
+         try:
+             # Load existing dataset and merge
+             dataset = load_dataset(
+                 "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation",
+                 split="train"
+             )
+             existing_data = dataset.to_dict()
+
+             # Add new data
+             for key in data:
+                 if key not in existing_data:
+                     existing_data[key] = ["" if key in ["username", "model_type"] else None] * len(next(iter(existing_data.values())))
+                 existing_data[key].extend(data[key])
+
+         except Exception as e:
+             logging.warning(f"Could not load existing dataset, creating new one: {str(e)}")
+             existing_data = data
+
+         # Create and push dataset
+         updated_dataset = Dataset.from_dict(existing_data)
+         updated_dataset.push_to_hub(
+             "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation"
+         )
+
+         return {"message": "Data saved successfully"}
+
+     except Exception as e:
+         logging.error(f"Save error: {str(e)}")
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to save data: {str(e)}"
+         )

+ # Make sure to keep the static files mounting
  app.mount("/home", StaticFiles(directory="static", html=True), name="home")

- # Run the app
+ # Startup event to create cache directory if it doesn't exist
+ @app.on_event("startup")
+ async def startup_event():
+     os.makedirs("./cache", exist_ok=True)
+
  if __name__ == "__main__":
      import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=7860)
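
For reference, a minimal client sketch against the endpoints above (hypothetical usage, not part of the commit: it assumes the server is running locally on port 7860, that the requests library is installed, and that the credentials match one of the server's login_N/password_N environment variables):

import requests

BASE = "http://localhost:7860"

# /login takes OAuth2 form fields (not JSON)
tokens = requests.post(
    BASE + "/login",
    data={"username": "example_user", "password": "example_pass"},  # hypothetical credentials
).json()

# /search takes a JSON body plus a Bearer access token; with n=1 per model it
# returns two results, one from all-mpnet-base-v2 and one from openai
results = requests.post(
    BASE + "/search",
    json={"query": "What is true happiness?"},
    headers={"Authorization": f"Bearer {tokens['access_token']}"},
).json()
for r in results:
    print(r["model_type"], round(r["similarity"], 3), r["text"][:80])

# /refresh takes the refresh token in a JSON body (the RefreshRequest model)
tokens = requests.post(
    BASE + "/refresh",
    json={"refresh_token": tokens["refresh_token"]},
).json()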
requirements.txt CHANGED
@@ -1,8 +1,14 @@
- fastapi
- uvicorn
- pandas
- openai
- python-dotenv
- pyarrow
- python-jose[cryptography]
- python-multipart
+ fastapi==0.109.2
+ uvicorn==0.27.1
+ python-jose==3.3.0
+ python-multipart==0.0.6  # Required for OAuth2 form handling
+ pydantic==2.6.1
+ openai==1.12.0
+ pandas==2.2.0
+ numpy==1.26.3
+ torch==2.1.2  # For sentence-transformers
+ sentence-transformers==2.3.1
+ datasets==2.17.0
+ huggingface-hub==0.20.3
+ diskcache==5.6.3
+ python-dotenv==1.0.1  # For environment variable management
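
Note that pyarrow was dropped from the pinned list even though pandas.read_parquet needs it; it should still arrive transitively via datasets, which depends on pyarrow. A small sketch (hypothetical, not part of the commit) to confirm the pinned set resolves after pip install -r requirements.txt:

from importlib.metadata import version, PackageNotFoundError

# Print the installed version of each pinned distribution, plus pyarrow,
# which is expected to be pulled in transitively by datasets
for pkg in ["fastapi", "uvicorn", "python-jose", "python-multipart", "pydantic",
            "openai", "pandas", "numpy", "torch", "sentence-transformers",
            "datasets", "huggingface-hub", "diskcache", "python-dotenv", "pyarrow"]:
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "NOT INSTALLED")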