ethanlshen commited on
Commit
2645fa8
1 Parent(s): 70c86a7

Update superposed/llama/superpose.py

Browse files
Files changed (1) hide show
  1. superposed/llama/superpose.py +1 -1
superposed/llama/superpose.py CHANGED
@@ -198,7 +198,7 @@ class Superpose(nn.Module):
198
  probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
199
  next_token = torch.gather(probs_idx, -1, torch.topk(probs_sort, k, dim=-1)[1])
200
  # Set all other probs to 0
201
- new_probs_map = torch.zeros(probs.shape).bool()
202
  new_probs_map[torch.repeat_interleave(torch.arange(n_prompts), k), torch.flatten(next_token)] = True
203
  new_probs = torch.where(new_probs_map, probs, 0)
204
  # Renormalize
 
198
  probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
199
  next_token = torch.gather(probs_idx, -1, torch.topk(probs_sort, k, dim=-1)[1])
200
  # Set all other probs to 0
201
+ new_probs_map = torch.zeros(probs.shape, device="cuda").bool()
202
  new_probs_map[torch.repeat_interleave(torch.arange(n_prompts), k), torch.flatten(next_token)] = True
203
  new_probs = torch.where(new_probs_map, probs, 0)
204
  # Renormalize