DumbledoreWiz commited on
Commit
0d3f848
·
verified ·
1 Parent(s): 4f2b7b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -23
app.py CHANGED
@@ -22,33 +22,24 @@ else:
22
  logging.error(f"Model file not found: {model_path}")
23
  raise FileNotFoundError(f"Model file not found: {model_path}")
24
 
25
- # Create a custom configuration
26
- config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
27
- config.num_labels = len(labels)
28
- config.id2label = {str(i): label for i, label in enumerate(labels)}
29
- config.label2id = {label: str(i) for i, label in enumerate(labels)}
30
- logging.info(f"Custom config created with {len(labels)} labels")
31
 
32
- # Load the model with the custom configuration
33
- logging.info("Loading the model with custom configuration")
34
- model = ViTForImageClassification(config)
 
 
 
 
 
 
35
 
36
  try:
37
  # Load the state dict
38
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
39
-
40
- # Check if the state dict keys match the model's keys
41
- model_keys = set(model.state_dict().keys())
42
- loaded_keys = set(state_dict.keys())
43
-
44
- if model_keys != loaded_keys:
45
- logging.warning("Mismatch in state dict keys. Attempting to adjust...")
46
- # Adjust keys if necessary (e.g., remove 'module.' prefix if it exists)
47
- new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
48
- model.load_state_dict(new_state_dict)
49
- else:
50
- model.load_state_dict(state_dict)
51
-
52
  logging.info("Model loaded successfully")
53
  except Exception as e:
54
  logging.error(f"Error loading model: {str(e)}")
@@ -57,10 +48,11 @@ except Exception as e:
57
  model.eval()
58
  logging.info("Model set to evaluation mode")
59
 
60
- # Load or create feature extractor
61
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
62
  logging.info("Feature extractor loaded")
63
 
 
64
  logging.info("Model and feature extractor loaded successfully")
65
 
66
  # Define the prediction function
 
22
  logging.error(f"Model file not found: {model_path}")
23
  raise FileNotFoundError(f"Model file not found: {model_path}")
24
 
25
+ # Create label mappings consistent with training
26
+ id2label = {str(i): label for i, label in enumerate(labels)}
27
+ label2id = {label: str(i) for i, label in enumerate(labels)}
 
 
 
28
 
29
+ # Load the model with custom label mapping
30
+ logging.info("Loading the model with custom label mapping")
31
+ model = ViTForImageClassification.from_pretrained(
32
+ "google/vit-base-patch16-224-in21k",
33
+ num_labels=len(labels),
34
+ id2label=id2label,
35
+ label2id=label2id,
36
+ ignore_mismatched_sizes=True
37
+ )
38
 
39
  try:
40
  # Load the state dict
41
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
42
+ model.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
43
  logging.info("Model loaded successfully")
44
  except Exception as e:
45
  logging.error(f"Error loading model: {str(e)}")
 
48
  model.eval()
49
  logging.info("Model set to evaluation mode")
50
 
51
+ # Load feature extractor
52
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
53
  logging.info("Feature extractor loaded")
54
 
55
+
56
  logging.info("Model and feature extractor loaded successfully")
57
 
58
  # Define the prediction function