patrickramos commited on
Commit
82be63b
1 Parent(s): 617c3e2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +1 -1
README.md CHANGED
@@ -99,7 +99,7 @@ inputs = {k: v.to(device) for k, v in feature_extractor(img, return_tensors='pt'
99
 
100
  # extract patch embeddings
101
  with torch.no_grad():
102
- patch_embeddings = rearrange(model(**inputs).last_hidden_state, 'b d h w -> b (h w) d').cpu()
103
 
104
  # classify
105
  pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True))
 
99
 
100
  # extract patch embeddings
101
  with torch.no_grad():
102
+ patch_embeddings = model(**inputs).last_hidden_state[0].permute(1, 2, 0).view(7*7, 2048).cpu()
103
 
104
  # classify
105
  pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True))