ethanlshen commited on
Commit
8b0fc91
1 Parent(s): 2645fa8

Set devices

Browse files
Files changed (1) hide show
  1. superposed/llama/superpose.py +4 -4
superposed/llama/superpose.py CHANGED
@@ -50,8 +50,8 @@ class Superpose(nn.Module):
50
  self.alive_seq = initial_tokens
51
  self.fin_seq = initial_tokens
52
  self.smoothing = smoothing
53
- self.alive_log_probs = torch.zeros(self.n_prompts, self.n_drafts)
54
- self.fin_log_probs = torch.full((self.n_prompts, self.n_drafts), float("-inf"))
55
  self.alpha = alpha
56
  self.verbose = verbose
57
  self.penalty = penalty
@@ -214,7 +214,7 @@ class Superpose(nn.Module):
214
  Superposition matrix
215
  """
216
  # Create superposition matrix
217
- mixing_matrix = torch.zeros(self.n_prompts, self.vocab_size)
218
  # Convert draft log probs to probabilities
219
  weightings = log_prob_to_prob(self.alive_log_probs)
220
  # Update probabilities in superposition matrix with draft probabilities
@@ -242,7 +242,7 @@ class Superpose(nn.Module):
242
  # Start timer
243
  start_time = datetime.now()
244
  # Create distribution matrix
245
- next_token_probs = torch.zeros(self.n_prompts, self.n_drafts, 32000)
246
  if probs is not None:
247
  # Loop over all prefixes
248
  for p_idx in range(len(alive_seq)):
 
50
  self.alive_seq = initial_tokens
51
  self.fin_seq = initial_tokens
52
  self.smoothing = smoothing
53
+ self.alive_log_probs = torch.zeros(self.n_prompts, self.n_drafts, device="cuda")
54
+ self.fin_log_probs = torch.full((self.n_prompts, self.n_drafts), float("-inf"), device="cuda")
55
  self.alpha = alpha
56
  self.verbose = verbose
57
  self.penalty = penalty
 
214
  Superposition matrix
215
  """
216
  # Create superposition matrix
217
+ mixing_matrix = torch.zeros(self.n_prompts, self.vocab_size, device="cuda")
218
  # Convert draft log probs to probabilities
219
  weightings = log_prob_to_prob(self.alive_log_probs)
220
  # Update probabilities in superposition matrix with draft probabilities
 
242
  # Start timer
243
  start_time = datetime.now()
244
  # Create distribution matrix
245
+ next_token_probs = torch.zeros(self.n_prompts, self.n_drafts, 32000, device="cuda")
246
  if probs is not None:
247
  # Loop over all prefixes
248
  for p_idx in range(len(alive_seq)):