Crystalcareai committed on
Commit 495a5d0
1 Parent(s): 904dcda

Update inference.py

Files changed (1)
  1. inference.py +103 -31
inference.py CHANGED
@@ -1,49 +1,121 @@
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
- from accelerate import infer_auto_device_map, init_empty_weights, dispatch_model

- model_path = "Crystalcareai/Quiet-Star-Custom"

  n_ahead = 8
  n_ahead_talk = 4
  merged_talk_heads = True

- model = AutoModelForCausalLM.from_pretrained(model_path,
-     max_thoughts=n_ahead + n_ahead_talk + 1,
-     merged_talk_heads=merged_talk_heads,
-     merged_lm_and_talk_heads=False,
-     merged_lm_and_think_heads=True,
-     use_concat_talk_head=True,
-     use_shallow_think=True,
-     use_shallow_talk=False,
-     use_complex_think_head=False,
-     use_complex_talk_head=True,
-     use_weighted_talk_head=True,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
- )
-
- model.eval()

  tokenizer = AutoTokenizer.from_pretrained(model_path)
- model.tokenizer = tokenizer  # Set the tokenizer attribute of the model

- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)

- # Convert prompt to tokens
  prompt_template = "[INST] {prompt} [/INST]"
- prompt = "It is not always easy to see who is related to whom -- and in which ways. The following argument pertains to this question: To begin with, Lesley is a close friend of Fernando. Moreover, being a close friend of Fernando or a schoolmate of Lowell is sufficient for being a great-grandfather of Leroy. It follows that Lesley is a great-grandfather of Leroy. Is the argument, given the explicitly stated premises, deductively valid or invalid?"

- input_ids = tokenizer(
-     prompt_template.format(prompt=prompt),
-     return_tensors='pt'
- ).input_ids.to(model.device)

- attention_mask = torch.ones_like(input_ids)

- max_length = 256

- output_ids, _ = model.generate(input_ids, attention_mask=attention_mask, max_length=max_length, streamer=streamer)

- print(tokenizer.decode(output_ids[0], skip_special_tokens=False))

  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

+ def compute_memory_used_pct(device):
+     memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
+     memory_pct = (
+         memory_used
+         / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
+         * 100
+     )
+     return memory_pct
+
+ model_path = "./out"

  n_ahead = 8
  n_ahead_talk = 4
  merged_talk_heads = True

+ # Load the model
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     max_thoughts=n_ahead + n_ahead_talk + 1,
+     merged_talk_heads=merged_talk_heads,
+     merged_lm_and_talk_heads=False,
+     merged_lm_and_think_heads=True,
+     use_concat_talk_head=True,
+     use_shallow_think=True,
+     use_shallow_talk=False,
+     use_complex_think_head=False,
+     use_complex_talk_head=True,
+     use_weighted_talk_head=True,
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ )

+ # Load the tokenizer and assign it to the model instance for compatibility
  tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model.tokenizer = tokenizer

+ model.use_end_thought_token = True
+ model.use_start_thought_token = True
+ model.wandb_enabled = True
+ model.n_ahead = n_ahead
+ model.n_passes = 2
+ model.eval_mode = True
+ model.first_run = False
+ model.kill_after = 100
+ model.rm_initialized = True
+ model.original_mode = False

+ # Custom generate function
+ def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
+     with torch.no_grad():
+         finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+         for cur_token_idx in range(max_new_tokens):
+             # Sample the next token
+             new_ids = model(
+                 input_ids[~finished_generating],
+                 attention_mask=attention_mask[~finished_generating]
+             )['logits']
+             # Mask out the start and end thought tokens so we don't accidentally sample them
+             new_ids[:, :, model.tokenizer.vocab_size:] = -float("inf")
+             for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                 # Find the index of the last token that is not padding
+                 base_answer_ids = input_ids[answer_idx]
+                 new_answer_ids = new_ids[list_idx]
+                 last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                 new_ids_sampled = torch.multinomial(
+                     torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)
+                 # Assign the new id to the last token
+                 if last_token_idx + 1 >= len(base_answer_ids):
+                     # Add padding everywhere
+                     new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long,
+                                              device=input_ids.device)
+                     input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                     attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+                 attention_mask[answer_idx, last_token_idx + 1] = 1
+                 input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+                 if new_ids_sampled == model.tokenizer.eos_token_id or new_ids_sampled == model.tokenizer.bos_token_id or new_ids_sampled == model.tokenizer.pad_token_id:
+                     finished_generating[answer_idx] = 1
+                 # Check if the end token is generated
+                 if new_ids_sampled == model.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
+                     finished_generating[answer_idx] = 1
+             if finished_generating.all():
+                 break
+             streamer.put(new_ids_sampled)
+         return input_ids, attention_mask
+
+ # Formulate your prompt
  prompt_template = "[INST] {prompt} [/INST]"

+ prompt = "You're standing on the surface of the Earth. "\
+     "You walk one mile south, one mile west and one mile north. "\
+     "You end up exactly where you started. Where are you?"
+
+ # Convert prompt to tokens
+ tokens = tokenizer(prompt_template.format(prompt=prompt), return_tensors='pt').input_ids.to(model.device)
+
+ # Generate an attention mask
+ attention_mask = torch.where(tokens != tokenizer.pad_token_id, torch.ones_like(tokens), torch.zeros_like(tokens)).to(model.device)
+
+ streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

+ # Generate output using the custom generate function
+ output_ids, _ = custom_generate(
+     model,
+     input_ids=tokens,
+     attention_mask=attention_mask,
+     max_new_tokens=512,
+     streamer=streamer,
+     temperature=0.9,
+ )

+ generated_text = ""

+ print()  # Print a newline after streaming is complete

+ # Cleanup if necessary
+ torch.cuda.empty_cache()
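
A minimal sketch (not part of the commit) of how the compute_memory_used_pct helper defined above could be called once generation has finished; the device index 0 and the print format are assumptions:

# Hypothetical call site for the helper above; device index 0 is an assumption.
if torch.cuda.is_available():
    peak_pct = compute_memory_used_pct(0)
    print(f"Peak GPU memory used: {peak_pct:.1f}% of device 0")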