# app.py by huntrezz (Hugging Face file page: commit e687f81 "Update app.py", verified; 4.3 kB).
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import RegularGridInterpolator
# NOTE(review): interp2d was removed in SciPy 1.14, where this import raises ImportError;
# drop it once no code in this file uses it.
from scipy.interpolate import interp2d
from transformers import DPTImageProcessor
# Run on CUDA when PyTorch reports it is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class CompressedStudentModel(nn.Module):
    """Small convolutional encoder-decoder that predicts a one-channel depth map.

    Encoder: three pairs of 3x3 conv + ReLU (64, 128, 256 channels), with a
    2x max-pool after each of the first two pairs, so features end up at 1/4
    of the input resolution. Decoder: two stride-2 transposed convolutions
    (to 128, then 64 channels) each followed by ReLU, and a final 3x3 conv
    down to a single channel. For inputs whose height and width are multiples
    of 4, the output has the same spatial size as the input.

    The layer order inside each ``nn.Sequential`` fixes the state-dict keys
    (``encoder.0.weight``, ...). It must stay as-is so the trained checkpoint
    loaded in this file still matches.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch):
            # padding=1 with a 3x3 kernel preserves height and width.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.ReLU()]

        def upconv_relu(in_ch, out_ch):
            # Output size is (n - 1) * 2 - 2 * 1 + 3 + 1 == 2 * n, i.e. doubles H and W.
            return [
                nn.ConvTranspose2d(
                    in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1
                ),
                nn.ReLU(),
            ]

        self.encoder = nn.Sequential(
            *conv_relu(3, 64),
            *conv_relu(64, 64),
            nn.MaxPool2d(2),
            *conv_relu(64, 128),
            *conv_relu(128, 128),
            nn.MaxPool2d(2),
            *conv_relu(128, 256),
            *conv_relu(256, 256),
        )
        self.decoder = nn.Sequential(
            *upconv_relu(256, 128),
            *upconv_relu(128, 64),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Map a batch of 3-channel images to a 1-channel depth prediction."""
        return self.decoder(self.encoder(x))
# Build the student network on the selected device, then restore the trained
# weights from the local checkpoint (tensors are mapped straight onto `device`).
model = CompressedStudentModel().to(device)
_state_dict = torch.load("huntrezz_depth_v2.pt", map_location=device)
model.load_state_dict(_state_dict)
del _state_dict  # the weights now live in the model's parameters
model.eval()
# NOTE(review): `processor` is not referenced by any code in this file, yet
# from_pretrained still loads it (a Hub download unless cached) at start-up.
# Confirm whether it is needed.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
def preprocess_image(image):
    """Turn one video frame into a normalised model input.

    The frame is resized to 128 pixels wide by 72 tall. It is then converted
    from height-width-channel layout to a channel-first batch of one, cast to
    float32, moved to ``device`` and divided by 255.

    Args:
        image: H x W x C numpy array. The model expects C == 3; values are
            presumably 0-255 uint8 (confirm with the caller).

    Returns:
        Tensor of shape (1, C, 72, 128) on ``device``.
    """
    frame = cv2.resize(image, (128, 72))  # cv2.resize takes (width, height)
    batch = torch.from_numpy(frame).float().permute(2, 0, 1)[None]
    return batch.to(device) / 255.0
def resample_grid(values, num=255):
    """Bilinearly resample a regular grid onto ``num`` x ``num`` points.

    Args:
        values: Array of shape (rows, cols) or (rows, cols, channels). Each of
            the first two axes should have at least 2 samples. Trailing axes
            are interpolated independently.
        num: Number of sample points along each of the first two axes.

    Returns:
        Float array of shape ``(num, num) + values.shape[2:]``. Entry [i, j] is
        the interpolated value at row ``linspace(0, rows - 1, num)[i]`` and
        column ``linspace(0, cols - 1, num)[j]``. This is the same
        corner-aligned sampling that ``np.meshgrid`` over those linspaces
        produces for plotting.
    """
    grid = np.asarray(values, dtype=float)
    rows, cols = grid.shape[:2]
    interpolator = RegularGridInterpolator((np.arange(rows), np.arange(cols)), grid)
    row_pts, col_pts = np.meshgrid(
        np.linspace(0, rows - 1, num), np.linspace(0, cols - 1, num), indexing="ij"
    )
    return interpolator(np.stack([row_pts, col_pts], axis=-1))


def plot_depth_map(depth_map, original_image):
    """Render two 3-D surface plots of a depth map and return them as an image.

    The left panel colours the surface with the viridis colormap of the depth.
    The right panel colours it with the frame's own pixels.

    Args:
        depth_map: 2-D array of depth values. The z-axis is fixed to [0, 1], so
            values are expected to be normalised to that range.
        original_image: H x W x 3 frame with 0-255 values, used to colour the
            second surface. It is resized to the depth map's shape first.

    Returns:
        uint8 array of shape (height, width, 3) with the rendered figure's RGB pixels.
    """
    num = 255  # samples per axis of the plotted surface
    rows, cols = depth_map.shape
    x, y = np.meshgrid(np.linspace(0, cols - 1, num), np.linspace(0, rows - 1, num))
    z = resample_grid(depth_map, num)

    small = cv2.resize(original_image, (cols, rows))  # cv2 wants (width, height)
    if small.ndim == 2:  # grayscale: replicate into three channels
        small = np.repeat(small[:, :, np.newaxis], 3, axis=2)
    # Interpolate all three channels together so facecolors is (num, num, 3).
    rgb = np.clip(resample_grid(small[:, :, :3] / 255.0, num), 0.0, 1.0)

    fig = plt.figure(figsize=(32, 9))
    try:
        panels = (
            (121, plt.cm.viridis(z), "Depth Map Color"),
            (122, rgb, "RGB Color"),
        )
        for position, facecolors, title in panels:
            ax = fig.add_subplot(position, projection="3d")
            ax.plot_surface(x, y, z, facecolors=facecolors, shade=False)
            ax.set_zlim(0, 1)
            ax.view_init(elev=150, azim=90)
            ax.set_title(title)
            ax.set_axis_off()
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which was removed in Matplotlib
        # 3.10. It also reports the true pixel size of the rendered canvas.
        image = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    finally:
        # A new figure is created for every frame; close it so pyplot does not
        # keep it alive and leak memory.
        plt.close(fig)
    return image
def ensure_rgb(image):
    """Return *image* as a C-contiguous height x width x 3 array.

    - 2-D frames, and frames with one or two channels (grey, grey + alpha),
      have their first channel copied into all three.
    - Frames with four or more channels (e.g. RGBA) keep only the first three.
    - Three-channel input is returned with its channel order unchanged.

    Raises:
        ValueError: if *image* is neither 2-D nor 3-D.
    """
    frame = np.asarray(image)
    if frame.ndim == 2:
        frame = frame[:, :, np.newaxis]
    if frame.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D image array, got shape {frame.shape}")
    channels = frame.shape[2]
    if channels < 3:
        frame = np.repeat(frame[:, :, :1], 3, axis=2)
    elif channels > 3:
        frame = frame[:, :, :3]
    # cv2 cannot handle channel-sliced (non-contiguous) views.
    return np.ascontiguousarray(frame)


def normalize_depth_map(depth):
    """Min-max scale *depth* into [0, 1] as a float array.

    A constant map (max == min) would divide by zero, so it is returned as all
    zeros instead.
    """
    depth = np.asarray(depth, dtype=float)
    lowest = depth.min()
    span = depth.max() - lowest
    if span == 0:
        return np.zeros_like(depth)
    return (depth - lowest) / span


@torch.inference_mode()
def process_frame(image):
    """Estimate depth for one webcam frame and return the rendered 3-D plots.

    Args:
        image: numpy frame from the Gradio webcam stream, or None when no
            frame is available.

    Returns:
        The RGB image produced by ``plot_depth_map``, or None when *image* is None.
    """
    if image is None:
        return None
    rgb = ensure_rgb(image)
    raw_depth = model(preprocess_image(rgb)).squeeze().cpu().numpy()
    depth_map = normalize_depth_map(raw_depth)
    # gr.Image passes numpy frames to the callback in RGB order, so the frame is
    # used for colouring as-is. A BGR->RGB swap here would invert red and blue.
    return plot_depth_map(depth_map, rgb)
# Live webcam demo: each streamed frame is passed to process_frame and the
# returned rendering is displayed as the output image.
webcam = gr.Image(sources="webcam", streaming=True)
interface = gr.Interface(fn=process_frame, inputs=webcam, outputs="image", live=True)
interface.launch()