nishantgaurav23 committed · 08ce49b
1 Parent(s): 06506f4

Upload 2 files

- embedding_processor.py +150 -0
- requirements.txt +8 -0
embedding_processor.py
ADDED
@@ -0,0 +1,150 @@
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from typing import List, Callable, Optional
import glob
from tqdm import tqdm
import pickle
import torch.nn.functional as F
import functools
from datetime import datetime

# Run everything on CPU. A bare torch.device("cpu") call has no global effect,
# so the retriever below passes device="cpu" to the model explicitly.

# Logging configuration
LOGGING_CONFIG = {
    'enabled': True,
    'functions': {
        'encode': True,
        'store_embeddings': True,
        'search': True,
        'load_and_process_csvs': True
    }
}

def log_function(func: Callable) -> Callable:
    """Decorator to log function inputs and outputs."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not LOGGING_CONFIG['enabled'] or not LOGGING_CONFIG['functions'].get(func.__name__, False):
            return func(*args, **kwargs)

        # Distinguish bound methods (log the class name) from module-level functions
        if args and hasattr(args[0], '__class__'):
            class_name = args[0].__class__.__name__
        else:
            class_name = func.__module__

        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        # Skip `self` when logging a method's arguments
        log_args = args[1:] if class_name != func.__module__ else args

        def format_arg(arg):
            if isinstance(arg, torch.Tensor):
                return f"Tensor(shape={list(arg.shape)}, device={arg.device})"
            elif isinstance(arg, list):
                return f"List(len={len(arg)})"
            elif isinstance(arg, str) and len(arg) > 100:
                return f"String(len={len(arg)}): {arg[:100]}..."
            return arg

        formatted_args = [format_arg(arg) for arg in log_args]
        formatted_kwargs = {k: format_arg(v) for k, v in kwargs.items()}

        print(f"\n{'='*80}")
        print(f"[{timestamp}] FUNCTION CALL: {class_name}.{func.__name__}")
        print("INPUTS:")
        print(f"  args: {formatted_args}")
        print(f"  kwargs: {formatted_kwargs}")

        result = func(*args, **kwargs)

        formatted_result = format_arg(result)
        print("OUTPUT:")
        print(f"  {formatted_result}")
        print(f"{'='*80}\n")

        return result
    return wrapper

class SentenceTransformerRetriever:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.device = torch.device("cpu")
            self.model = SentenceTransformer(model_name, device="cpu")
        self.doc_embeddings = None
        self.cache_dir = cache_dir
        self.cache_file = "embeddings.pkl"
        os.makedirs(cache_dir, exist_ok=True)

    def get_cache_path(self) -> str:
        return os.path.join(self.cache_dir, self.cache_file)

    @log_function
    def save_cache(self, cache_data: dict):
        cache_path = self.get_cache_path()
        if os.path.exists(cache_path):
            os.remove(cache_path)
        with open(cache_path, 'wb') as f:
            pickle.dump(cache_data, f)
        print(f"Cache saved at: {cache_path}")

    @log_function
    def load_cache(self) -> Optional[dict]:
        cache_path = self.get_cache_path()
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                print(f"Loading cache from: {cache_path}")
                return pickle.load(f)
        return None

    @log_function
    def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
        embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
        # L2-normalize so dot products between embeddings equal cosine similarity
        return F.normalize(embeddings, p=2, dim=1)

    @log_function
    def store_embeddings(self, embeddings: torch.Tensor):
        self.doc_embeddings = embeddings

def process_data(data_folder: str):
    retriever = SentenceTransformerRetriever()
    documents = []

    # Check cache first
    cache_data = retriever.load_cache()
    if cache_data is not None:
        print("Using cached embeddings")
        return cache_data

    # Process CSV files: flatten each row into one space-joined string
    csv_files = glob.glob(os.path.join(data_folder, "*.csv"))
    for csv_file in tqdm(csv_files, desc="Reading CSV files"):
        try:
            df = pd.read_csv(csv_file)
            texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
            documents.extend(texts)
        except Exception as e:
            print(f"Error processing file {csv_file}: {e}")
            continue

    # Guard against an empty corpus (encoding an empty list is not meaningful)
    if not documents:
        raise ValueError(f"No documents found in {data_folder}")

    # Generate embeddings
    embeddings = retriever.encode(documents)

    # Save cache
    cache_data = {
        'embeddings': embeddings,
        'documents': documents
    }
    retriever.save_cache(cache_data)

    return cache_data

if __name__ == "__main__":
    data_folder = "ESPN_data"
    process_data(data_folder)
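
Note that LOGGING_CONFIG enables logging for 'search' and 'load_and_process_csvs', neither of which is defined in this file; retrieval presumably lives elsewhere in the Space. For reference only, here is a minimal sketch of what a search over these embeddings could look like; the function name, signature, and example query are assumptions, not part of this commit. Because encode() L2-normalizes its output, a plain matrix product against a query embedding yields cosine similarities directly.

# Hypothetical sketch, not from this commit: top-k cosine-similarity search.
# Assumes store_embeddings() has been called with the cached document embeddings.
def search(retriever: SentenceTransformerRetriever, query: str,
           documents: List[str], k: int = 5):
    query_emb = retriever.encode([query])                         # (1, dim), L2-normalized
    scores = (retriever.doc_embeddings @ query_emb.T).squeeze(1)  # cosine similarities, (N,)
    top = torch.topk(scores, k=min(k, len(documents)))
    return [(documents[int(i)], float(s)) for s, i in zip(top.values, top.indices)]

# Example usage (query text is illustrative):
# data = process_data("ESPN_data")
# retriever = SentenceTransformerRetriever()
# retriever.store_embeddings(data['embeddings'])
# print(search(retriever, "college football scores", data['documents']))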
requirements.txt
ADDED
@@ -0,0 +1,8 @@
streamlit
torch
sentence-transformers
requests
pandas
numpy
tqdm
python-dotenv
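
Note: requests and python-dotenv are not imported by embedding_processor.py, so they presumably serve other files in this Space (likely the Streamlit app, given streamlit is listed here).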