ariG23498 committed on
Commit
bb9d17d
1 Parent(s): 5d23e43

updating the language model with better prompting

Browse files
Files changed (1) hide show
  1. app.py +24 -13
app.py CHANGED
@@ -28,27 +28,38 @@ def run_language_model(edit_prompt, device):
28
  )
29
  tokenizer = AutoTokenizer.from_pretrained(language_model_id)
30
  messages = [
31
- {
32
- "role": "system",
33
- "content": "Follow the examples and return the expected output",
34
- },
35
  {"role": "user", "content": "swap mountain and lion"}, # example 1
36
  {"role": "assistant", "content": "mountain, lion"}, # example 1
37
  {"role": "user", "content": "change the dog with cat"}, # example 2
38
  {"role": "assistant", "content": "dog, cat"}, # example 2
39
- {"role": "user", "content": "replace the human with a boat"}, # example 3
40
- {"role": "assistant", "content": "human, boat"}, # example 3
41
- {"role": "user", "content": edit_prompt},
 
 
 
 
 
 
42
  ]
43
  text = tokenizer.apply_chat_template(
44
- messages, tokenize=False, add_generation_prompt=True
 
 
45
  )
46
  model_inputs = tokenizer([text], return_tensors="pt").to(device)
47
- generated_ids = language_model.generate(model_inputs.input_ids, max_new_tokens=512)
48
- generated_ids = [
49
- output_ids[len(input_ids) :]
50
- for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
51
- ]
 
 
 
 
 
 
52
  response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
53
  to_replace, replace_with = response.split(", ")
54
 
 
28
  )
29
  tokenizer = AutoTokenizer.from_pretrained(language_model_id)
30
  messages = [
31
+ {"role": "system", "content": "Follow the examples and return the expected output"},
 
 
 
32
  {"role": "user", "content": "swap mountain and lion"}, # example 1
33
  {"role": "assistant", "content": "mountain, lion"}, # example 1
34
  {"role": "user", "content": "change the dog with cat"}, # example 2
35
  {"role": "assistant", "content": "dog, cat"}, # example 2
36
+ {"role": "user", "content": "change the cat with a dog"}, # example 3
37
+ {"role": "assistant", "content": "cat, dog"}, # example 3
38
+ {"role": "user", "content": "replace the human with a boat"}, # example 4
39
+ {"role": "assistant", "content": "human, boat"}, # example 4
40
+ {"role": "user", "content": "in the above example change the background to the alps"}, # example 5
41
+ {"role": "assistant", "content": "background, alps"}, # example 5
42
+ {"role": "user", "content": "edit the house into a mansion"}, # example 6
43
+ {"role": "assistant", "content": "house, a mansion"}, # example 6
44
+ {"role": "user", "content": edit_prompt}
45
  ]
46
  text = tokenizer.apply_chat_template(
47
+ messages,
48
+ tokenize=False,
49
+ add_generation_prompt=True
50
  )
51
  model_inputs = tokenizer([text], return_tensors="pt").to(device)
52
+ with torch.no_grad():
53
+ generated_ids = language_model.generate(
54
+ model_inputs.input_ids,
55
+ max_new_tokens=512,
56
+ temperature=0.0,
57
+ do_sample=False
58
+ )
59
+
60
+ generated_ids = [
61
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
62
+ ]
63
  response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
64
  to_replace, replace_with = response.split(", ")
65