File size: 2,120 Bytes
d1b31ce c83f8fa d1b31ce c83f8fa d1b31ce c83f8fa d1b31ce c83f8fa d1b31ce c83f8fa d1b31ce c83f8fa d1b31ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
"""
File: model.py
Author: Elena Ryumina and Dmitry Ryumin
Description: This module provides functions for loading and processing a pre-trained deep learning model
for facial expression recognition.
License: MIT License
"""
import torch
import requests
from PIL import Image
from torchvision import transforms
from pytorch_grad_cam import GradCAM
# Importing necessary components for the Gradio app
from app.config import config_data
from app.model_architectures import ResNet50, LSTMPyTorch
def load_model(model_url, model_path):
    """Download a model checkpoint from ``model_url`` and save it to ``model_path``.

    The response body is streamed to disk in 8192-byte chunks.

    Args:
        model_url: URL of the checkpoint file.
        model_path: Local path the checkpoint is written to; an existing file
            at that path is overwritten on a successful (2xx) response.

    Returns:
        ``model_path`` on success, or ``None`` if the request, an HTTP error
        status, or the file write failed (the error is printed).
    """
    try:
        # A timeout keeps a stalled server from hanging the importing process.
        with requests.get(model_url, stream=True, timeout=60) as response:
            # Without this check an HTTP error page (404, 403, ...) would be
            # saved as the "model" and reported as success, clobbering any
            # good file already on disk; torch.load would then fail later
            # with an unrelated-looking unpickling error.
            response.raise_for_status()
            with open(model_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        return model_path
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
# Download both checkpoints at import time and build the models in eval mode.
# map_location="cpu" lets checkpoints saved from GPU tensors load on CPU-only
# hosts; load_state_dict copies the values into the models' own parameters,
# so the final models are unchanged.
# NOTE(review): torch.load unpickles the downloaded file, so the URLs in
# config_data must point to a trusted source (consider weights_only=True on
# torch >= 1.13).
path_static = load_model(config_data.model_static_url, config_data.model_static_path)
if path_static is None:
    # load_model already printed the underlying error; fail with a clear
    # message instead of an opaque error from torch.load(None).
    raise RuntimeError(
        f"Failed to download static model weights from {config_data.model_static_url}"
    )
pth_model_static = ResNet50(7, channels=3)
pth_model_static.load_state_dict(torch.load(path_static, map_location="cpu"))
pth_model_static.eval()

path_dynamic = load_model(config_data.model_dynamic_url, config_data.model_dynamic_path)
if path_dynamic is None:
    raise RuntimeError(
        f"Failed to download dynamic model weights from {config_data.model_dynamic_url}"
    )
pth_model_dynamic = LSTMPyTorch()
pth_model_dynamic.load_state_dict(torch.load(path_dynamic, map_location="cpu"))
pth_model_dynamic.eval()

# Grad-CAM over the static model, using its ``layer4`` as the target layer.
target_layers = [pth_model_static.layer4]
cam = GradCAM(model=pth_model_static, target_layers=target_layers)
class PreprocessInput(torch.nn.Module):
    """Turn a channels-first RGB image tensor into mean-centred BGR float32.

    Defined once at module level rather than on every ``pth_processing`` call.
    The inherited ``torch.nn.Module.__init__`` is sufficient, so no custom
    constructor is needed.
    """

    def forward(self, x):
        """Return a float32 copy of ``x`` with dim 0 reversed and means removed.

        Assumes ``x`` is (C, H, W) with C == 3 in RGB order, as produced by
        ``transforms.PILToTensor`` on an RGB image.
        """
        x = x.to(torch.float32)
        # Reverse the channel axis (RGB -> BGR). torch.flip returns a new
        # tensor, so the in-place subtractions below never touch the input.
        x = torch.flip(x, dims=(0,))
        # Per-channel means in (B, G, R) order.
        # NOTE(review): these appear to be the VGGFace2 mean pixel values;
        # confirm against the training pipeline.
        x[0, :, :] -= 91.4953
        x[1, :, :] -= 103.8827
        x[2, :, :] -= 131.0912
        return x


def pth_processing(fp, target_size=(224, 224)):
    """Convert a PIL image into a batched input tensor for the static model.

    Args:
        fp: A ``PIL.Image.Image``. Images not in mode "RGB" (e.g. grayscale,
            palette, RGBA) are converted to RGB first so the tensor always has
            the 3 channels the preprocessing and model expect.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A float32 tensor of shape ``(1, 3, height, width)`` in BGR channel
        order with the per-channel means subtracted.
    """
    img = fp if fp.mode == "RGB" else fp.convert("RGB")
    img = img.resize(target_size, Image.Resampling.NEAREST)
    transform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()])
    # Add the leading batch dimension.
    return torch.unsqueeze(transform(img), 0)
|