zamborg commited on
Commit
9fa3fe8
·
1 Parent(s): 0674c7e

converted to tensor

Browse files
Files changed (1) hide show
  1. model.py +2 -1
model.py CHANGED
@@ -63,7 +63,8 @@ class VirTexModel():
63
  subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
64
 
65
  if prompt is not "":
66
- cap_tokens = self.tokenizer.encode(prompt)
 
67
  subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
68
 
69
  predictions: List[Dict[str, Any]] = []
 
63
  subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
64
 
65
  if prompt is not "":
66
+ cap_tokens = self.tokenizer.encode(prompt)
67
+ cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
68
  subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
69
 
70
  predictions: List[Dict[str, Any]] = []