Yeyito commited on
Commit
28487fe
·
1 Parent(s): f6d7d3d

num_z = 50 for all

Browse files
detect-pretrain-code-contamination/src/run.py CHANGED
@@ -86,11 +86,7 @@ def sample_generation(sentence, model, tokenizer, args,data_name):
86
  input_ids = torch.tensor(tokenizer.encode(prefix)).unsqueeze(0)
87
  input_ids = input_ids.to(model.device)
88
 
89
- output = None
90
- if data_name != "cais/mmlu" and data_name != "gsm8k":
91
- output = model.generate(input_ids, max_new_tokens=len(sentence.split())-half_sentence_index, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
92
- else:
93
- output = model.generate(input_ids, max_new_tokens=(len(sentence.split())-half_sentence_index)/2, min_new_tokens=1, num_return_sequences=int(args['num_z']/2), pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
94
  # print(output)
95
  complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
96
 
@@ -103,7 +99,7 @@ def RMIA_1(text,target_loss,ref_loss,model1,tokenizer1,ratio_gen,neighbors_dl):
103
  return result
104
 
105
  def get_neighbors(text,ref_loss,model2,tokenizer2,ratio_gen,data_name):
106
- cur_args = {'prefix_length': ratio_gen, 'num_z': 100, 'generate_args': {'do_sample': True}}
107
  neighbors = sample_generation(text, model2, tokenizer2, cur_args,data_name)
108
  neighbors_dl = DataLoader(neighbors, batch_size=32, shuffle=False)
109
  return neighbors_dl
 
86
  input_ids = torch.tensor(tokenizer.encode(prefix)).unsqueeze(0)
87
  input_ids = input_ids.to(model.device)
88
 
89
+ output = model.generate(input_ids, max_new_tokens=(len(sentence.split())-half_sentence_index), min_new_tokens=1, num_return_sequences=int(args['num_z']), pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
 
 
 
 
90
  # print(output)
91
  complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
92
 
 
99
  return result
100
 
101
  def get_neighbors(text,ref_loss,model2,tokenizer2,ratio_gen,data_name):
102
+ cur_args = {'prefix_length': ratio_gen, 'num_z': 50, 'generate_args': {'do_sample': True}}
103
  neighbors = sample_generation(text, model2, tokenizer2, cur_args,data_name)
104
  neighbors_dl = DataLoader(neighbors, batch_size=32, shuffle=False)
105
  return neighbors_dl