Yeyito committed on
Commit
a0bf640
·
1 Parent(s): b56563d

Trying to prevent OOM on mmlu and gsm8k by halving seq len

Browse files
detect-pretrain-code-contamination/src/run.py CHANGED
@@ -75,7 +75,7 @@ def calculatePerplexity(sentence, model, tokenizer, gpu):
75
  all_prob.append(probability)
76
  return torch.exp(loss).item(), all_prob, loss.item()
77
 
78
- def sample_generation(sentence, model, tokenizer, args):
79
  half_sentence_index = math.ceil(len(sentence.split())*args['prefix_length'])
80
 
81
  if half_sentence_index > 0:
@@ -86,7 +86,11 @@ def sample_generation(sentence, model, tokenizer, args):
86
  input_ids = torch.tensor(tokenizer.encode(prefix)).unsqueeze(0)
87
  input_ids = input_ids.to(model.device)
88
 
89
- output = model.generate(input_ids, max_new_tokens=len(sentence.split())-half_sentence_index, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
 
 
 
 
90
  # print(output)
91
  complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
92
 
@@ -98,9 +102,9 @@ def RMIA_1(text,target_loss,ref_loss,model1,tokenizer1,ratio_gen,neighbors_dl):
98
  result = torch.count_nonzero(target_losses_z < target_loss).item() / len(target_losses_z)
99
  return result
100
 
101
- def get_neighbors(text,ref_loss,model2,tokenizer2,ratio_gen):
102
  cur_args = {'prefix_length': ratio_gen, 'num_z': 100, 'generate_args': {'do_sample': True}}
103
- neighbors = sample_generation(text, model2, tokenizer2, cur_args)
104
  neighbors_dl = DataLoader(neighbors, batch_size=32, shuffle=False)
105
  return neighbors_dl
106
 
@@ -134,7 +138,7 @@ def evaluate_data(test_data, col_name, target_model, ref_model, ratio_gen, data_
134
  counter = 0
135
  for ex in tqdm(test_data):
136
  text = ex[col_name]
137
- new_ex = get_neighbors(text,inference2_pass[counter][2],model2,tokenizer2,ratio_gen)
138
  counter = counter + 1
139
  neighbors_dls.append(new_ex)
140
  unload_model(model2,tokenizer2)
 
75
  all_prob.append(probability)
76
  return torch.exp(loss).item(), all_prob, loss.item()
77
 
78
+ def sample_generation(sentence, model, tokenizer, args,data_name):
79
  half_sentence_index = math.ceil(len(sentence.split())*args['prefix_length'])
80
 
81
  if half_sentence_index > 0:
 
86
  input_ids = torch.tensor(tokenizer.encode(prefix)).unsqueeze(0)
87
  input_ids = input_ids.to(model.device)
88
 
89
# Decide how many tokens to generate for the neighbor sample.
# For mmlu/gsm8k the continuation length is halved to reduce peak
# memory (the OOM mitigation this change is meant to introduce).
# BUG FIX: the original guard was
#     if data_name != "cais/mmlu" or data_name != "gsm8k":
# which is always true (every string differs from at least one of the
# two), so the halved branch was unreachable. Use a membership test.
suffix_len = len(sentence.split()) - half_sentence_index
if data_name in ("cais/mmlu", "gsm8k"):
    # Integer division: max_new_tokens must be an int; keep at least 1
    # so generation is never asked for zero tokens.
    max_new = max(1, suffix_len // 2)
else:
    max_new = suffix_len
output = model.generate(input_ids, max_new_tokens=max_new, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
94
  # print(output)
95
  complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
96
 
 
102
  result = torch.count_nonzero(target_losses_z < target_loss).item() / len(target_losses_z)
103
  return result
104
 
105
def get_neighbors(text, ref_loss, model2, tokenizer2, ratio_gen, data_name):
    """Build a DataLoader of generated "neighbor" texts for *text*.

    Uses the reference model (*model2*/*tokenizer2*) to sample 100
    perturbed continuations of *text* (prefix fraction *ratio_gen*),
    then batches them for the downstream RMIA loss pass.

    Note: *ref_loss* is accepted for signature compatibility but is not
    used here.
    """
    sampling_cfg = {'prefix_length': ratio_gen, 'num_z': 100, 'generate_args': {'do_sample': True}}
    generated = sample_generation(text, model2, tokenizer2, sampling_cfg, data_name)
    loader = DataLoader(generated, batch_size=32, shuffle=False)
    return loader
110
 
 
138
  counter = 0
139
  for ex in tqdm(test_data):
140
  text = ex[col_name]
141
+ new_ex = get_neighbors(text,inference2_pass[counter][2],model2,tokenizer2,ratio_gen,data_name)
142
  counter = counter + 1
143
  neighbors_dls.append(new_ex)
144
  unload_model(model2,tokenizer2)