Spaces:
Runtime error
Runtime error
Trying to prevent OOM on mmlu and gsm8k by halving seq len
Browse files
detect-pretrain-code-contamination/src/run.py
CHANGED
@@ -75,7 +75,7 @@ def calculatePerplexity(sentence, model, tokenizer, gpu):
|
|
75 |
all_prob.append(probability)
|
76 |
return torch.exp(loss).item(), all_prob, loss.item()
|
77 |
|
78 |
-
def sample_generation(sentence, model, tokenizer, args):
|
79 |
half_sentence_index = math.ceil(len(sentence.split())*args['prefix_length'])
|
80 |
|
81 |
if half_sentence_index > 0:
|
@@ -86,7 +86,11 @@ def sample_generation(sentence, model, tokenizer, args):
|
|
86 |
input_ids = torch.tensor(tokenizer.encode(prefix)).unsqueeze(0)
|
87 |
input_ids = input_ids.to(model.device)
|
88 |
|
89 |
-
output =
|
|
|
|
|
|
|
|
|
90 |
# print(output)
|
91 |
complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
92 |
|
@@ -98,9 +102,9 @@ def RMIA_1(text,target_loss,ref_loss,model1,tokenizer1,ratio_gen,neighbors_dl):
|
|
98 |
result = torch.count_nonzero(target_losses_z < target_loss).item() / len(target_losses_z)
|
99 |
return result
|
100 |
|
101 |
-
def get_neighbors(text,ref_loss,model2,tokenizer2,ratio_gen):
|
102 |
cur_args = {'prefix_length': ratio_gen, 'num_z': 100, 'generate_args': {'do_sample': True}}
|
103 |
-
neighbors = sample_generation(text, model2, tokenizer2, cur_args)
|
104 |
neighbors_dl = DataLoader(neighbors, batch_size=32, shuffle=False)
|
105 |
return neighbors_dl
|
106 |
|
@@ -134,7 +138,7 @@ def evaluate_data(test_data, col_name, target_model, ref_model, ratio_gen, data_
|
|
134 |
counter = 0
|
135 |
for ex in tqdm(test_data):
|
136 |
text = ex[col_name]
|
137 |
-
new_ex = get_neighbors(text,inference2_pass[counter][2],model2,tokenizer2,ratio_gen)
|
138 |
counter = counter + 1
|
139 |
neighbors_dls.append(new_ex)
|
140 |
unload_model(model2,tokenizer2)
|
|
|
75 |
all_prob.append(probability)
|
76 |
return torch.exp(loss).item(), all_prob, loss.item()
|
77 |
|
78 |
+
def sample_generation(sentence, model, tokenizer, args,data_name):
|
79 |
half_sentence_index = math.ceil(len(sentence.split())*args['prefix_length'])
|
80 |
|
81 |
if half_sentence_index > 0:
|
|
|
86 |
input_ids = torch.tensor(tokenizer.encode(prefix)).unsqueeze(0)
|
87 |
input_ids = input_ids.to(model.device)
|
88 |
|
89 |
+
output = None
|
90 |
+
if data_name != "cais/mmlu" or data_name != "gsm8k":
|
91 |
+
output = model.generate(input_ids, max_new_tokens=len(sentence.split())-half_sentence_index, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
|
92 |
+
else:
|
93 |
+
output = model.generate(input_ids, max_new_tokens=(len(sentence.split())-half_sentence_index)/2, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
|
94 |
# print(output)
|
95 |
complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
96 |
|
|
|
102 |
result = torch.count_nonzero(target_losses_z < target_loss).item() / len(target_losses_z)
|
103 |
return result
|
104 |
|
105 |
+
def get_neighbors(text,ref_loss,model2,tokenizer2,ratio_gen,data_name):
|
106 |
cur_args = {'prefix_length': ratio_gen, 'num_z': 100, 'generate_args': {'do_sample': True}}
|
107 |
+
neighbors = sample_generation(text, model2, tokenizer2, cur_args,data_name)
|
108 |
neighbors_dl = DataLoader(neighbors, batch_size=32, shuffle=False)
|
109 |
return neighbors_dl
|
110 |
|
|
|
138 |
counter = 0
|
139 |
for ex in tqdm(test_data):
|
140 |
text = ex[col_name]
|
141 |
+
new_ex = get_neighbors(text,inference2_pass[counter][2],model2,tokenizer2,ratio_gen,data_name)
|
142 |
counter = counter + 1
|
143 |
neighbors_dls.append(new_ex)
|
144 |
unload_model(model2,tokenizer2)
|