import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer


def compute_memory_used_pct(device):
    """Return peak CUDA memory allocated on ``device`` as a percentage of its total memory.

    Uses ``torch.cuda.max_memory_allocated`` (the peak allocation tracked by
    PyTorch), not the current allocation.
    """
    bytes_per_gib = 1024**3
    memory_used = torch.cuda.max_memory_allocated(device) / bytes_per_gib
    total_memory = torch.cuda.get_device_properties(device).total_memory / bytes_per_gib
    return memory_used / total_memory * 100


model_path = "./out"
n_ahead = 8
n_ahead_talk = 4
merged_talk_heads = True

# Load the model. The extra keyword arguments are not standard transformers
# options; presumably they are consumed by the checkpoint's remote code
# (trust_remote_code=True) -- their meaning is defined there, not in this file.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    max_thoughts=n_ahead + n_ahead_talk + 1,
    merged_talk_heads=merged_talk_heads,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Load the tokenizer and attach it to the model: custom_generate reads
# model.tokenizer for pad/stop ids and the base vocabulary size.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# custom_generate locates each row's last real token, and pads new columns,
# using pad_token_id; with no pad token those operations fail on None, so fall
# back to EOS (EOS already ends generation, so stopping behavior is unchanged).
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
model.tokenizer = tokenizer

# Runtime flags set directly on the model object. They are presumably read by
# the remote modeling code (not visible here) -- confirm semantics there.
model.use_end_thought_token = True
model.use_start_thought_token = True
# NOTE(review): wandb logging is enabled even though this is an inference
# script -- confirm this is intended.
model.wandb_enabled = True
model.n_ahead = n_ahead
model.n_passes = 2
model.eval_mode = True
model.first_run = False
model.kill_after = 100
model.rm_initialized = True
model.original_mode = False


def _stop_token_ids(tokenizer):
    """Return the set of token ids that end generation for a row.

    Covers EOS, BOS, PAD and ``<|/assistant|>``. Ids the tokenizer does not
    define (``None``) are dropped. If ``<|/assistant|>`` is not in the
    vocabulary, transformers tokenizers typically map it to the unknown-token
    id, so sampling that id also stops the row.
    """
    candidates = (
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        tokenizer.pad_token_id,
        tokenizer.convert_tokens_to_ids("<|/assistant|>"),
    )
    return {token_id for token_id in candidates if token_id is not None}


def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
    """Sample up to ``max_new_tokens`` new tokens per row, without a KV cache.

    Each step re-runs ``model`` on the full sequences of the rows that are still
    generating. It samples one token per row from the logits at that row's last
    non-pad position and writes the token into the next slot. When a row has no
    free slot left, every row is widened with a pad column.

    Args:
        model: Causal LM whose output supports ``["logits"]``. Must carry a
            ``tokenizer`` attribute.
        input_ids: 2-D ``(batch, seq)`` token ids. Assumed right-padded with
            ``model.tokenizer.pad_token_id``.
        attention_mask: 2-D mask with the same shape as ``input_ids``.
        max_new_tokens: Maximum number of sampling steps.
        streamer: Object with ``put``/``end`` (e.g. ``TextStreamer``), or
            ``None``. Only the token sampled for the last active row is streamed
            each step, so streaming is meaningful for batch size 1. ``end()`` is
            called once generation finishes, to flush buffered text.
        **kwargs: ``temperature`` (default 1.0) divides the logits before the
            softmax.

    Returns:
        ``(input_ids, attention_mask)`` including the sampled tokens. The
        caller's tensors are not modified.
    """
    tokenizer = model.tokenizer
    pad_token_id = tokenizer.pad_token_id
    temperature = kwargs.get("temperature", 1.0)
    stop_token_ids = _stop_token_ids(tokenizer)

    # Work on copies so the caller's tensors are never written in place.
    input_ids = input_ids.clone()
    attention_mask = attention_mask.clone()

    with torch.no_grad():
        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
        for _ in range(max_new_tokens):
            active = ~finished_generating
            # Run the model only on the rows that are still generating.
            logits = model(input_ids[active], attention_mask=attention_mask[active])["logits"]
            # Mask out ids past the base vocabulary (the start/end thought
            # tokens) so we never sample them.
            logits[:, :, tokenizer.vocab_size:] = -float("inf")

            # list_idx indexes the active-only logits; answer_idx is the row
            # in the full batch.
            for list_idx, answer_idx in enumerate(active.nonzero(as_tuple=True)[0]):
                row = input_ids[answer_idx]
                # Index of the last token that is not padding.
                last_token_idx = (row != pad_token_id).nonzero(as_tuple=True)[0].max()
                probs = torch.nn.functional.softmax(
                    logits[list_idx, last_token_idx] / temperature, dim=-1
                )
                new_ids_sampled = torch.multinomial(probs, 1)

                next_idx = last_token_idx + 1
                if next_idx >= input_ids.shape[1]:
                    # No free slot in this row: widen every row by one pad column.
                    new_padding = torch.full(
                        (len(input_ids), 1), pad_token_id, dtype=torch.long, device=input_ids.device
                    )
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
                attention_mask[answer_idx, next_idx] = 1
                input_ids[answer_idx, next_idx] = new_ids_sampled

                if new_ids_sampled.item() in stop_token_ids:
                    finished_generating[answer_idx] = True

            if finished_generating.all():
                break
            if streamer is not None:
                streamer.put(new_ids_sampled)

    # Flush text the streamer is still buffering (e.g. the final word).
    if streamer is not None:
        streamer.end()
    return input_ids, attention_mask


# Formulate your prompt
prompt_template = "[INST] {prompt} [/INST]"
prompt = (
    "You're standing on the surface of the Earth. "
    "You walk one mile south, one mile west and one mile north. "
    "You end up exactly where you started. Where are you?"
)
# Tokenize the prompt. The tokenizer already returns an attention mask, so use
# it directly instead of rebuilding one from pad_token_id (which fails when the
# tokenizer defines no pad token).
encoded = tokenizer(prompt_template.format(prompt=prompt), return_tensors="pt").to(model.device)
tokens = encoded.input_ids
attention_mask = encoded.attention_mask

# custom_generate only ever puts sampled tokens (never the prompt), so
# skip_prompt must stay False or the first sampled token would be dropped.
streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

# Generate output using the custom generate function
output_ids, _ = custom_generate(
    model,
    input_ids=tokens,
    attention_mask=attention_mask,
    max_new_tokens=512,
    streamer=streamer,
    temperature=0.9,
)

print()  # Print a newline after streaming is complete

# Release cached GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()