JMalott commited on
Commit
5629479
1 Parent(s): a1eeb65

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +3 -4
min_dalle/min_dalle.py CHANGED
@@ -14,8 +14,8 @@ import time
14
 
15
  torch.set_grad_enabled(False)
16
  torch.set_num_threads(os.cpu_count())
17
- torch.backends.cudnn.enabled = True
18
- torch.backends.cudnn.allow_tf32 = True
19
 
20
  MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
21
  IMAGE_TOKEN_COUNT = 256
@@ -252,8 +252,7 @@ class MinDalle:
252
  token_index=token_indices[[i]]
253
  )
254
 
255
- del attention_state
256
- del image_tokens
257
 
258
 
259
  with torch.cuda.amp.autocast(dtype=torch.float16):
 
14
 
15
  torch.set_grad_enabled(False)
16
  torch.set_num_threads(os.cpu_count())
17
+ torch.backends.cudnn.enabled = False
18
+ torch.backends.cudnn.allow_tf32 = False
19
 
20
  MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
21
  IMAGE_TOKEN_COUNT = 256
 
252
  token_index=token_indices[[i]]
253
  )
254
 
255
+
 
256
 
257
 
258
  with torch.cuda.amp.autocast(dtype=torch.float16):