sayakpaul HF staff commited on
Commit
1548a67
1 Parent(s): c4b2b37

remove cuda

Browse files
Files changed (1) hide show
  1. generic_utils.py +5 -5
generic_utils.py CHANGED
@@ -24,10 +24,10 @@ def show_cam_on_image(img, mask):
24
 
25
 
26
  # initialize ViT pretrained
27
- model = vit_LRP(pretrained=True).cuda()
28
  model.eval()
29
  attribution_generator = LRP(model)
30
- model_baseline = vit(pretrained=True).cuda()
31
  model_baseline.eval()
32
  baselines_generator = Baselines(model_baseline)
33
 
@@ -37,16 +37,16 @@ def generate_visualization(
37
  ):
38
  if LRP:
39
  transformer_attribution = attribution_generator.generate_LRP(
40
- original_image.unsqueeze(0).cuda(), method=method, index=class_index
41
  ).detach()
42
  else:
43
  if method == "gradcam":
44
  transformer_attribution = baselines_generator.generate_cam_attn(
45
- original_image.unsqueeze(0).cuda(), index=class_index
46
  ).detach()
47
  else:
48
  transformer_attribution = baselines_generator.generate_rollout(
49
- original_image.unsqueeze(0).cuda()
50
  ).detach()
51
  if method != "full":
52
  transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
 
24
 
25
 
26
  # initialize ViT pretrained
27
+ model = vit_LRP(pretrained=True)
28
  model.eval()
29
  attribution_generator = LRP(model)
30
+ model_baseline = vit(pretrained=True)
31
  model_baseline.eval()
32
  baselines_generator = Baselines(model_baseline)
33
 
 
37
  ):
38
  if LRP:
39
  transformer_attribution = attribution_generator.generate_LRP(
40
+ original_image.unsqueeze(0), method=method, index=class_index
41
  ).detach()
42
  else:
43
  if method == "gradcam":
44
  transformer_attribution = baselines_generator.generate_cam_attn(
45
+ original_image.unsqueeze(0), index=class_index
46
  ).detach()
47
  else:
48
  transformer_attribution = baselines_generator.generate_rollout(
49
+ original_image.unsqueeze(0)
50
  ).detach()
51
  if method != "full":
52
  transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)