DumbledoreWiz commited on
Commit
ebb842f
·
verified ·
1 Parent(s): 4e3ba00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  import gradio as gr
4
  from PIL import Image
5
  import os
@@ -26,21 +26,20 @@ else:
26
  id2label = {str(i): label for i, label in enumerate(labels)}
27
  label2id = {label: str(i) for i, label in enumerate(labels)}
28
 
29
- # Load the model with custom label mapping
30
- logging.info("Loading the model with custom label mapping")
31
- model = ViTForImageClassification.from_pretrained(
32
- "google/vit-base-patch16-224-in21k",
33
- num_labels=len(labels),
34
- id2label=id2label,
35
- label2id=label2id,
36
- ignore_mismatched_sizes=True
37
- )
38
 
39
  try:
40
- # Load the state dict
41
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
42
  model.load_state_dict(state_dict)
43
- logging.info("Model loaded successfully")
44
  except Exception as e:
45
  logging.error(f"Error loading model: {str(e)}")
46
  raise
 
1
  import torch
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
3
  import gradio as gr
4
  from PIL import Image
5
  import os
 
26
  id2label = {str(i): label for i, label in enumerate(labels)}
27
  label2id = {label: str(i) for i, label in enumerate(labels)}
28
 
29
+ # Create a configuration for the model
30
+ config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
31
+ config.num_labels = len(labels)
32
+ config.id2label = id2label
33
+ config.label2id = label2id
34
+
35
+ # Initialize the model with the configuration
36
+ model = ViTForImageClassification(config)
 
37
 
38
  try:
39
+ # Load the state dict of the fine-tuned model
40
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
41
  model.load_state_dict(state_dict)
42
+ logging.info("Fine-tuned model loaded successfully")
43
  except Exception as e:
44
  logging.error(f"Error loading model: {str(e)}")
45
  raise