Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -22,33 +22,24 @@ else:
|
|
22 |
logging.error(f"Model file not found: {model_path}")
|
23 |
raise FileNotFoundError(f"Model file not found: {model_path}")
|
24 |
|
25 |
-
# Create
|
26 |
-
|
27 |
-
|
28 |
-
config.id2label = {str(i): label for i, label in enumerate(labels)}
|
29 |
-
config.label2id = {label: str(i) for i, label in enumerate(labels)}
|
30 |
-
logging.info(f"Custom config created with {len(labels)} labels")
|
31 |
|
32 |
-
# Load the model with
|
33 |
-
logging.info("Loading the model with custom
|
34 |
-
model = ViTForImageClassification(
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
try:
|
37 |
# Load the state dict
|
38 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
39 |
-
|
40 |
-
# Check if the state dict keys match the model's keys
|
41 |
-
model_keys = set(model.state_dict().keys())
|
42 |
-
loaded_keys = set(state_dict.keys())
|
43 |
-
|
44 |
-
if model_keys != loaded_keys:
|
45 |
-
logging.warning("Mismatch in state dict keys. Attempting to adjust...")
|
46 |
-
# Adjust keys if necessary (e.g., remove 'module.' prefix if it exists)
|
47 |
-
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
48 |
-
model.load_state_dict(new_state_dict)
|
49 |
-
else:
|
50 |
-
model.load_state_dict(state_dict)
|
51 |
-
|
52 |
logging.info("Model loaded successfully")
|
53 |
except Exception as e:
|
54 |
logging.error(f"Error loading model: {str(e)}")
|
@@ -57,10 +48,11 @@ except Exception as e:
|
|
57 |
model.eval()
|
58 |
logging.info("Model set to evaluation mode")
|
59 |
|
60 |
-
# Load
|
61 |
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
62 |
logging.info("Feature extractor loaded")
|
63 |
|
|
|
64 |
logging.info("Model and feature extractor loaded successfully")
|
65 |
|
66 |
# Define the prediction function
|
|
|
22 |
logging.error(f"Model file not found: {model_path}")
|
23 |
raise FileNotFoundError(f"Model file not found: {model_path}")
|
24 |
|
25 |
+
# Create label mappings consistent with training
|
26 |
+
id2label = {str(i): label for i, label in enumerate(labels)}
|
27 |
+
label2id = {label: str(i) for i, label in enumerate(labels)}
|
|
|
|
|
|
|
28 |
|
29 |
+
# Load the model with custom label mapping
|
30 |
+
logging.info("Loading the model with custom label mapping")
|
31 |
+
model = ViTForImageClassification.from_pretrained(
|
32 |
+
"google/vit-base-patch16-224-in21k",
|
33 |
+
num_labels=len(labels),
|
34 |
+
id2label=id2label,
|
35 |
+
label2id=label2id,
|
36 |
+
ignore_mismatched_sizes=True
|
37 |
+
)
|
38 |
|
39 |
try:
|
40 |
# Load the state dict
|
41 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
42 |
+
model.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
logging.info("Model loaded successfully")
|
44 |
except Exception as e:
|
45 |
logging.error(f"Error loading model: {str(e)}")
|
|
|
48 |
model.eval()
|
49 |
logging.info("Model set to evaluation mode")
|
50 |
|
51 |
+
# Load feature extractor
|
52 |
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
53 |
logging.info("Feature extractor loaded")
|
54 |
|
55 |
+
|
56 |
logging.info("Model and feature extractor loaded successfully")
|
57 |
|
58 |
# Define the prediction function
|