|
|
|
|
|
|
|
|
|
import cv2, os, torch, re |
|
import matplotlib.pyplot as plt |
|
from scipy.ndimage import zoom |
|
import numpy as np |
|
from model_two import MakiAlexNet |
|
from tqdm import tqdm |
|
|
|
|
|
# Script-level settings. TOP_ACCURACY_PERCENTILE and TEST_IMAGE are not
# referenced by the code visible in this part of the file.
TOP_ACCURACY_PERCENTILE = 10

TEST_IMAGE = "dataset/root/train/left1_frame_10.jpg"
MODEL_PARAMS = "alexnet_2.0.pth"  # checkpoint holding the model's state dict
GIF_STORE = "dataset/gifs2/"  # where the generated videos are written
TRAIN_STORE = "dataset/root/train/"  # directory scanned for input frames

model = MakiAlexNet()
# Deserialize the weights onto the CPU first: without map_location a
# checkpoint saved from a GPU session cannot be restored on a machine that has
# no CUDA device. The model is moved to the GPU below when one is available.
model.load_state_dict(torch.load(MODEL_PARAMS, map_location="cpu"))
model.eval()

if torch.cuda.is_available():
    model = model.cuda()
    print("Running on cuda")

# Debug output: list the model's attributes and the names of its sub-modules.
print(dir(model))

for name, module in model.named_modules():
    print(name)
|
|
|
|
|
# Matches names such as "left1_frame_10": a left/right prefix, the video
# number, the literal "_frame_" and the frame number. Compiled once because
# the function is called for every file in the training directory.
_FRAME_NAME_PATTERN = re.compile(r"(left|right)([0-9]+)(_frame_)([0-9]+)")


def extract_file_paths(filename):
    """Split a frame filename into its frame number, prefix and video number.

    The first occurrence of ``(left|right)<video>_frame_<frame>`` anywhere in
    *filename* is used, so a directory part or a file extension is ignored.

    Args:
        filename: Name (or path) of a frame image, e.g. ``left1_frame_10.jpg``.

    Returns:
        A tuple ``(frame_no, frame_name, video_no)`` of strings, where
        ``frame_name`` is ``"left"`` or ``"right"``. For the example above the
        result is ``("10", "left", "1")``.

    Raises:
        ValueError: If *filename* does not contain the expected pattern.
    """
    match = _FRAME_NAME_PATTERN.search(filename)
    if match is None:
        # Without this check the failure surfaces as an AttributeError on
        # None, which hides which file was the offending one.
        raise ValueError(
            f"{filename!r} does not look like '<left|right><video>_frame_<frame>'"
        )
    frame_name, video_no, _, frame_no = match.groups()
    return frame_no, frame_name, video_no
|
|
|
|
|
def _natural_sort_key(path):
    """Return a sort key that compares the digit runs in *path* as integers."""
    # re.split with a capturing group alternates text and digit tokens, so the
    # odd positions are always the digit runs and two keys stay comparable.
    return [
        int(token) if index % 2 else token
        for index, token in enumerate(re.split(r"([0-9]+)", path))
    ]


def create_mp4_from_frames(file_name, frames, fps=20, output_dir=None):
    """Encode a collection of image files into ``<file_name>.mp4``.

    Args:
        file_name: Base name of the video file, without extension.
        frames: Paths of the image files to encode. They are written in
            natural order of their paths (digit runs compared as numbers), so
            frame 100 follows frame 99 instead of preceding frame 11.
        fps: Frame rate of the resulting video.
        output_dir: Directory the video is written to. Defaults to
            ``GIF_STORE`` resolved against the current working directory. It
            is created if it does not exist.

    Raises:
        ValueError: If *frames* is empty.
        OSError: If a frame cannot be read or the video file cannot be opened.
    """
    if not frames:
        raise ValueError("cannot create a video from an empty list of frames")

    ordered_frames = sorted(frames, key=_natural_sort_key)
    print("Sorted frames: ", ordered_frames)

    first_image = cv2.imread(ordered_frames[0])
    if first_image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"could not read frame: {ordered_frames[0]}")
    height, width, _ = first_image.shape

    if output_dir is None:
        output_dir = os.path.join(os.getcwd(), GIF_STORE)
    os.makedirs(output_dir, exist_ok=True)
    video_path = os.path.join(output_dir, f"{file_name}.mp4")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    try:
        if not video.isOpened():
            raise OSError(f"could not open video file for writing: {video_path}")
        for frame_path in ordered_frames:
            image = cv2.imread(frame_path)
            if image is None:
                raise OSError(f"could not read frame: {frame_path}")
            # cv2.imread returns BGR and VideoWriter.write expects BGR, so the
            # frame is written as-is; converting to RGB here would swap the
            # red and blue channels in the video.
            video.write(image)
    finally:
        # Always finalize the file, even if a frame fails part-way through.
        video.release()
|
|
|
|
|
|
|
def _write_pending_video(video_name, frame_paths):
    """Encode the overlay frames collected for one video, if there are any."""
    if video_name and frame_paths:
        create_mp4_from_frames(video_name, frame_paths)


# Directory for the per-frame overlay images; created up front because
# plt.savefig does not create missing directories.
raw_frame_dir = os.path.join(os.getcwd(), "dataset/gifs2/raw/")
os.makedirs(raw_frame_dir, exist_ok=True)

use_cuda = torch.cuda.is_available()

# Walk the training frames in sorted order. Frames of the same video share a
# filename prefix and are therefore adjacent, so a change of video name marks
# the point where the previous video's frames are complete.
current_video_name = None
selected_frames = []
for image_filename in tqdm(sorted(os.listdir(TRAIN_STORE)), desc="Running Images"):
    frame_no, frame_name, video_no = extract_file_paths(image_filename)
    obtained_video_name = video_no + "vid" + frame_name
    if current_video_name != obtained_video_name:
        _write_pending_video(current_video_name, selected_frames)
        selected_frames = []
        current_video_name = obtained_video_name

    # Read the frame once; the same RGB array feeds the model and the plot.
    rgb_image = cv2.imread(os.path.join(TRAIN_STORE, image_filename))
    rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)

    # Add a batch axis and move the channel axis in front of the two spatial
    # axes before handing the image to the model.
    img = torch.unsqueeze(torch.tensor(rgb_image.astype(np.float32)), 0)
    X = torch.einsum("BWHC->BCWH", img)
    if use_cuda:
        X = X.cuda()

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        model(X)

    # The forward pass leaves intermediate activations in model.layer_outputs.
    conv = model.layer_outputs['Conv2d']
    pred = model.layer_outputs["Linear"]
    pred_weights = model.f_linear.weight

    # Channels last, so the channel axis can be contracted with the weights.
    conv = torch.einsum("BCWH->BWHC", conv).cpu().detach().numpy()

    # Index of the highest-scoring output for the single image in the batch.
    target = int(np.argmax(pred.cpu().detach().numpy(), axis=1).squeeze())
    weights = pred_weights[target, :].cpu().detach().numpy()

    # Weighted sum of the feature maps using the predicted class's weights.
    heatmap = conv.squeeze(0) @ weights

    # Upsample the heatmap to the image size so the overlay lines up for any
    # input / feature-map resolution (224 / 12 for a 224 px image and a 12x12
    # feature map).
    zoom_factors = (
        rgb_image.shape[0] / heatmap.shape[0],
        rgb_image.shape[1] / heatmap.shape[1],
    )
    plt.figure(figsize=(12, 12))
    plt.imshow(rgb_image)
    plt.imshow(zoom(heatmap, zoom=zoom_factors), cmap='jet', alpha=0.5)

    # Single-digit frame numbers are padded to two digits in the file name.
    frame_no = frame_no.zfill(2)
    filename = video_no + frame_name + frame_no + ".jpg"
    file_path = os.path.join(raw_frame_dir, filename)
    plt.savefig(file_path)
    selected_frames.append(file_path)
    plt.close()

# A video is only written when the next one starts, so the frames of the last
# video are still pending here and must be flushed explicitly.
_write_pending_video(current_video_name, selected_frames)

# Stop here so that nothing after this point is executed. SystemExit is raised
# directly because exit() is a helper installed by the site module for
# interactive sessions.
raise SystemExit(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|