DawnC commited on
Commit
6ce2de6
1 Parent(s): 11ab9ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -3
app.py CHANGED
@@ -568,7 +568,6 @@ class BaseModel(nn.Module):
568
 
569
  def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
570
  super().__init__()
571
- self.device = device_mgr.get_optimal_device()
572
  self.backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
573
  self.feature_dim = self.backbone.classifier[1].in_features
574
  self.backbone.classifier = nn.Identity()
@@ -582,8 +581,6 @@ class BaseModel(nn.Module):
582
  nn.Linear(self.feature_dim, num_classes)
583
  )
584
 
585
- self.to(device)
586
-
587
  def forward(self, x):
588
  x = x.to(self.device)
589
  features = self.backbone(x)
 
568
 
569
  def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
570
  super().__init__()
 
571
  self.backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
572
  self.feature_dim = self.backbone.classifier[1].in_features
573
  self.backbone.classifier = nn.Identity()
 
581
  nn.Linear(self.feature_dim, num_classes)
582
  )
583
 
 
 
584
  def forward(self, x):
585
  x = x.to(self.device)
586
  features = self.backbone(x)