Commit: Set devices
superposed/llama/superpose.py
CHANGED
```diff
@@ -50,8 +50,8 @@ class Superpose(nn.Module):
         self.alive_seq = initial_tokens
         self.fin_seq = initial_tokens
         self.smoothing = smoothing
-        self.alive_log_probs = torch.zeros(self.n_prompts, self.n_drafts)
-        self.fin_log_probs = torch.full((self.n_prompts, self.n_drafts), float("-inf"))
+        self.alive_log_probs = torch.zeros(self.n_prompts, self.n_drafts, device="cuda")
+        self.fin_log_probs = torch.full((self.n_prompts, self.n_drafts), float("-inf"), device="cuda")
         self.alpha = alpha
         self.verbose = verbose
         self.penalty = penalty
@@ -214,7 +214,7 @@ class Superpose(nn.Module):
        SUperposition matrix
        """
        # Create superposition matrix
-       mixing_matrix = torch.zeros(self.n_prompts, self.vocab_size)
+       mixing_matrix = torch.zeros(self.n_prompts, self.vocab_size, device="cuda")
        # Convert draft log probs to probabilities
        weightings = log_prob_to_prob(self.alive_log_probs)
        # Update probabilities in superposition matrix with draft probabilities
@@ -242,7 +242,7 @@ class Superpose(nn.Module):
        # Start timer
        start_time = datetime.now()
        # Create distribution matrix
-       next_token_probs = torch.zeros(self.n_prompts, self.n_drafts, 32000)
+       next_token_probs = torch.zeros(self.n_prompts, self.n_drafts, 32000, device="cuda")
        if probs is not None:
            # Loop over all prefixes
            for p_idx in range(len(alive_seq)):
```
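All three hunks make the same fix: tensors that were previously allocated on the default (CPU) device are now created with `device="cuda"`, presumably so they start out on the same device as the model's other GPU tensors and avoid device-mismatch errors when they are combined later. Below is a minimal, device-agnostic sketch of that allocation pattern; the `make_state` helper, its argument names, and the CPU fallback are illustrative assumptions, and only the tensor shapes and initial values are taken from the diff.

```python
# Illustrative sketch, not code from the repository: allocate the working
# tensors directly on the target device instead of relying on the CPU default.
import torch

def make_state(n_prompts, n_drafts, vocab_size, device=None):
    # Fall back to CPU when CUDA is unavailable, unlike the hard-coded
    # device="cuda" in the diff above.
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # Initialised to 0, as in the diff (log-probabilities of live drafts).
    alive_log_probs = torch.zeros(n_prompts, n_drafts, device=device)
    # Initialised to -inf, as in the diff (log-probabilities of finished drafts).
    fin_log_probs = torch.full((n_prompts, n_drafts), float("-inf"), device=device)
    # Per-prompt weights over the vocabulary (the "superposition matrix").
    mixing_matrix = torch.zeros(n_prompts, vocab_size, device=device)
    return alive_log_probs, fin_log_probs, mixing_matrix
```

Routing the device through a single argument like this also keeps the same code runnable on a CPU-only machine for debugging, rather than failing wherever "cuda" is hard-coded.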