nishantgaurav23 commited on
Commit
08ce49b
1 Parent(s): 06506f4

Upload 2 files

Browse files
Files changed (2) hide show
  1. embedding_processor.py +150 -0
  2. requirements.txt +8 -0
embedding_processor.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ warnings.filterwarnings("ignore", category=UserWarning)
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from sentence_transformers import SentenceTransformer
9
+ from typing import List, Callable
10
+ import glob
11
+ from tqdm import tqdm
12
+ import pickle
13
+ import torch.nn.functional as F
14
+ import functools
15
+ from datetime import datetime
16
+
17
+ # Force CPU device
18
+ torch.device('cpu')
19
+
20
+ # Logging configuration
21
+ LOGGING_CONFIG = {
22
+ 'enabled': True,
23
+ 'functions': {
24
+ 'encode': True,
25
+ 'store_embeddings': True,
26
+ 'search': True,
27
+ 'load_and_process_csvs': True
28
+ }
29
+ }
30
+
31
+ def log_function(func: Callable) -> Callable:
32
+ """Decorator to log function inputs and outputs"""
33
+ @functools.wraps(func)
34
+ def wrapper(*args, **kwargs):
35
+ if not LOGGING_CONFIG['enabled'] or not LOGGING_CONFIG['functions'].get(func.__name__, False):
36
+ return func(*args, **kwargs)
37
+
38
+ if args and hasattr(args[0], '__class__'):
39
+ class_name = args[0].__class__.__name__
40
+ else:
41
+ class_name = func.__module__
42
+
43
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
44
+ log_args = args[1:] if class_name != func.__module__ else args
45
+
46
+ def format_arg(arg):
47
+ if isinstance(arg, torch.Tensor):
48
+ return f"Tensor(shape={list(arg.shape)}, device={arg.device})"
49
+ elif isinstance(arg, list):
50
+ return f"List(len={len(arg)})"
51
+ elif isinstance(arg, str) and len(arg) > 100:
52
+ return f"String(len={len(arg)}): {arg[:100]}..."
53
+ return arg
54
+
55
+ formatted_args = [format_arg(arg) for arg in log_args]
56
+ formatted_kwargs = {k: format_arg(v) for k, v in kwargs.items()}
57
+
58
+ print(f"\n{'='*80}")
59
+ print(f"[{timestamp}] FUNCTION CALL: {class_name}.{func.__name__}")
60
+ print(f"INPUTS:")
61
+ print(f" args: {formatted_args}")
62
+ print(f" kwargs: {formatted_kwargs}")
63
+
64
+ result = func(*args, **kwargs)
65
+
66
+ formatted_result = format_arg(result)
67
+ print(f"OUTPUT:")
68
+ print(f" {formatted_result}")
69
+ print(f"{'='*80}\n")
70
+
71
+ return result
72
+ return wrapper
73
+
74
+ class SentenceTransformerRetriever:
75
+ def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
76
+ with warnings.catch_warnings():
77
+ warnings.simplefilter("ignore")
78
+ self.device = torch.device("cpu")
79
+ self.model = SentenceTransformer(model_name, device="cpu")
80
+ self.doc_embeddings = None
81
+ self.cache_dir = cache_dir
82
+ self.cache_file = "embeddings.pkl"
83
+ os.makedirs(cache_dir, exist_ok=True)
84
+
85
+ def get_cache_path(self) -> str:
86
+ return os.path.join(self.cache_dir, self.cache_file)
87
+
88
+ @log_function
89
+ def save_cache(self, cache_data: dict):
90
+ cache_path = self.get_cache_path()
91
+ if os.path.exists(cache_path):
92
+ os.remove(cache_path)
93
+ with open(cache_path, 'wb') as f:
94
+ pickle.dump(cache_data, f)
95
+ print(f"Cache saved at: {cache_path}")
96
+
97
+ @log_function
98
+ def load_cache(self) -> dict:
99
+ cache_path = self.get_cache_path()
100
+ if os.path.exists(cache_path):
101
+ with open(cache_path, 'rb') as f:
102
+ print(f"Loading cache from: {cache_path}")
103
+ return pickle.load(f)
104
+ return None
105
+
106
+ @log_function
107
+ def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
108
+ embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
109
+ return F.normalize(embeddings, p=2, dim=1)
110
+
111
+ @log_function
112
+ def store_embeddings(self, embeddings: torch.Tensor):
113
+ self.doc_embeddings = embeddings
114
+
115
+ def process_data(data_folder: str):
116
+ retriever = SentenceTransformerRetriever()
117
+ documents = []
118
+
119
+ # Check cache first
120
+ cache_data = retriever.load_cache()
121
+ if cache_data is not None:
122
+ print("Using cached embeddings")
123
+ return cache_data
124
+
125
+ # Process CSV files
126
+ csv_files = glob.glob(os.path.join(data_folder, "*.csv"))
127
+ for csv_file in tqdm(csv_files, desc="Reading CSV files"):
128
+ try:
129
+ df = pd.read_csv(csv_file)
130
+ texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
131
+ documents.extend(texts)
132
+ except Exception as e:
133
+ print(f"Error processing file {csv_file}: {e}")
134
+ continue
135
+
136
+ # Generate embeddings
137
+ embeddings = retriever.encode(documents)
138
+
139
+ # Save cache
140
+ cache_data = {
141
+ 'embeddings': embeddings,
142
+ 'documents': documents
143
+ }
144
+ retriever.save_cache(cache_data)
145
+
146
+ return cache_data
147
+
148
+ if __name__ == "__main__":
149
+ data_folder = "ESPN_data"
150
+ process_data(data_folder)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ sentence-transformers
4
+ requests
5
+ pandas
6
+ numpy
7
+ tqdm
8
+ python-dotenv