"""TADBot / FER / prediction.py

Predict the emotion in face images using a trained FER model.
"""

import torch
import os
from torchvision import transforms
import numpy as np
# Select the best available device: Apple MPS, then CUDA, then CPU
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# Collect the paths of every image under ../FER/Images/
image_arr = []
for foldername, _, filenames in os.walk("../FER/Images/"):
    for filename in filenames:
        # Construct the full path to the file
        file_path = os.path.join(foldername, filename)
        image_arr.append(file_path)


def predict(model, image_path):
    from face_detection import face_detection

    with torch.no_grad():
        # Inference-time preprocessing: resize to the model's input size and
        # apply ImageNet normalization (no random augmentation at prediction time)
        transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # Detect and crop the face, then build a single-image batch on the device
        face = face_detection(image_path)
        image_tensor = transform(face).unsqueeze(0).to(device)
        model.eval()
        img_pred = model(image_tensor)
        # Take the top-3 predicted class indices for the single image
        topk = (3,)
        maxk = max(topk)
        _, pred = img_pred.topk(maxk, 1, True, True)
        img_pred = pred.t().squeeze().cpu().numpy()
        y_pred = np.array(img_pred).flatten()
        emotions = {
            0: "Surprise",
            1: "Fear",
            2: "Disgust",
            3: "Happy",
            4: "Sad",
            5: "Angry",
            6: "Neutral",
        }
        # Map each predicted class index to its emotion label
        labels = [emotions.get(int(i)) for i in y_pred]
        print(
            f"--> Image Path {image_path} [!] The predicted classes are {y_pred} and the labels are {labels}"
        )
        return labels
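

# Usage sketch (a minimal, hypothetical example): this module only defines
# predict() and collects image paths; it assumes a trained FER model is built
# and loaded elsewhere in the project. `load_trained_model` below is a
# placeholder name for that loading code, not an actual helper in this repo.
#
# if __name__ == "__main__":
#     model = load_trained_model().to(device)  # hypothetical loader
#     for path in image_arr:
#         predict(model, path)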