prithivida
committed on
Update torch code
Browse files
README.md
CHANGED
@@ -140,17 +140,20 @@ sparse_rep = expander.expand(
|
|
140 |
import torch
|
141 |
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
142 |
|
|
|
143 |
tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v1')
|
144 |
model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v1')
|
|
|
145 |
|
146 |
sentence = """The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."""
|
147 |
|
148 |
inputs = tokenizer(sentence, return_tensors='pt')
|
|
|
149 |
input_ids = inputs['input_ids']
|
|
|
150 |
attention_mask = inputs['attention_mask']
|
151 |
|
152 |
outputs = model(**inputs)
|
153 |
-
print(outputs.logits.shape)
|
154 |
|
155 |
logits, attention_mask = outputs.logits, attention_mask
|
156 |
relu_log = torch.log(1 + torch.relu(logits))
|
@@ -158,19 +161,18 @@ weighted_log = relu_log * attention_mask.unsqueeze(-1)
|
|
158 |
max_val, _ = torch.max(weighted_log, dim=1)
|
159 |
vector = max_val.squeeze()
|
160 |
|
|
|
161 |
cols = vector.nonzero().squeeze().cpu().tolist()
|
|
|
162 |
weights = vector[cols].cpu().tolist()
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
# Sort the dictionary by weights in descending order
|
171 |
-
sorted_token_weight_dict = {k: v for k, v in sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True) if v > 0}
|
172 |
-
print(sorted_token_weight_dict)
|
173 |
|
|
|
174 |
```
|
175 |
|
176 |
## BEIR Zeroshot OOD performance:
|
|
|
140 |
import torch
|
141 |
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
142 |
|
143 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
144 |
tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v1')
|
145 |
model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v1')
|
146 |
+
model.to(device)
|
147 |
|
148 |
sentence = """The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."""
|
149 |
|
150 |
inputs = tokenizer(sentence, return_tensors='pt')
|
151 |
+
inputs = {key: val.to(device) for key, val in inputs.items()}
|
152 |
input_ids = inputs['input_ids']
|
153 |
+
|
154 |
attention_mask = inputs['attention_mask']
|
155 |
|
156 |
outputs = model(**inputs)
|
|
|
157 |
|
158 |
logits, attention_mask = outputs.logits, attention_mask
|
159 |
relu_log = torch.log(1 + torch.relu(logits))
|
|
|
161 |
max_val, _ = torch.max(weighted_log, dim=1)
|
162 |
vector = max_val.squeeze()
|
163 |
|
164 |
+
|
165 |
cols = vector.nonzero().squeeze().cpu().tolist()
|
166 |
+
print("number of actual dimensions: ", len(cols))
|
167 |
weights = vector[cols].cpu().tolist()
|
168 |
|
169 |
+
d = {k: v for k, v in zip(cols, weights)}
|
170 |
+
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
|
171 |
+
bow_rep = []
|
172 |
+
for k, v in sorted_d.items():
|
173 |
+
bow_rep.append((reverse_voc[k], round(v,2)))
|
|
|
|
|
|
|
|
|
174 |
|
175 |
+
print("SPLADE BOW rep:\n", bow_rep)
|
176 |
```
|
177 |
|
178 |
## BEIR Zeroshot OOD performance:
|