# gaze-demo / demo.py
import os

import gradio as gr
import matplotlib

matplotlib.use("Agg")  # Use Agg backend for non-interactive plotting

import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

# Forward the Space secret to HF_TOKEN so the model repo can be downloaded
hf_token = os.environ.get("TOKEN_FROM_SECRET")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token
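# Load the Moondream model with remote code in fp16 on the GPU, pinned to a
# specific revision for reproducibility.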
model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map={"": "cuda"},
    revision="69420e0c6596863b4f0059e365fadc5cb388e8fd",
)
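# The visualizer expects face_boxes as (x, y, width, height) and gaze_points as
# (x, y), both in pixel coordinates of the image being drawn.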
def visualize_gaze_multi(face_boxes, gaze_points, image=None, show_plot=True):
    """Visualization function with reduced whitespace"""
    # Calculate figure size based on image aspect ratio
    if image is not None:
        height, width = image.shape[:2]
        aspect_ratio = width / height
        fig_height = 6  # Base height
        fig_width = fig_height * aspect_ratio
    else:
        width, height = 800, 600
        fig_width, fig_height = 10, 8

    # Create figure with tight layout
    fig = plt.figure(figsize=(fig_width, fig_height))
    ax = fig.add_subplot(111)

    if image is not None:
        ax.imshow(image)
    else:
        ax.set_facecolor("#1a1a1a")
        fig.patch.set_facecolor("#1a1a1a")

    colors = plt.cm.rainbow(np.linspace(0, 1, len(face_boxes)))
    for face_box, gaze_point, color in zip(face_boxes, gaze_points, colors):
        hex_color = "#{:02x}{:02x}{:02x}".format(
            int(color[0] * 255), int(color[1] * 255), int(color[2] * 255)
        )

        x, y, width_box, height_box = face_box
        gaze_x, gaze_y = gaze_point
        face_center_x = x + width_box / 2
        face_center_y = y + height_box / 2

        # Draw the face bounding box
        face_rect = plt.Rectangle(
            (x, y), width_box, height_box, fill=False, color=hex_color, linewidth=2
        )
        ax.add_patch(face_rect)

        # Draw the gaze line as many short segments that fade out toward the target
        points = 50
        alphas = np.linspace(0.8, 0, points)
        x_points = np.linspace(face_center_x, gaze_x, points)
        y_points = np.linspace(face_center_y, gaze_y, points)
        for i in range(points - 1):
            ax.plot(
                [x_points[i], x_points[i + 1]],
                [y_points[i], y_points[i + 1]],
                color=hex_color,
                alpha=alphas[i],
                linewidth=4,
            )

        # Mark the gaze target with a colored dot and a white center
        ax.scatter(gaze_x, gaze_y, color=hex_color, s=100, zorder=5)
        ax.scatter(gaze_x, gaze_y, color="white", s=50, zorder=6)
    # Set plot limits and remove axes
    ax.set_xlim(0, width)
    ax.set_ylim(height, 0)
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])

    # Remove padding around the plot
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)

    return fig
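# On ZeroGPU Spaces, @spaces.GPU requests a GPU for each call, held for at most
# the given duration in seconds.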
@spaces.GPU(duration=15)
def process_image(input_image):
    try:
        # Convert to PIL Image if needed
        if isinstance(input_image, np.ndarray):
            pil_image = Image.fromarray(input_image)
        else:
            pil_image = input_image

        # Get image encoding
        enc_image = model.encode_image(pil_image)

        # Detect faces
        faces = model.detect(enc_image, "face")["objects"]
        if not faces:
            return None, "No faces detected in the image."

        # Process each face
        face_boxes = []
        gaze_points = []
        for face in faces:
            face_center = (
                (face["x_min"] + face["x_max"]) / 2,
                (face["y_min"] + face["y_max"]) / 2,
            )
            gaze = model.detect_gaze(enc_image, face_center)
            if gaze is None:
                continue

            # Detections are normalized to [0, 1]; convert to pixel coordinates
            face_box = (
                face["x_min"] * pil_image.width,
                face["y_min"] * pil_image.height,
                (face["x_max"] - face["x_min"]) * pil_image.width,
                (face["y_max"] - face["y_min"]) * pil_image.height,
            )
            gaze_point = (
                gaze["x"] * pil_image.width,
                gaze["y"] * pil_image.height,
            )
            face_boxes.append(face_box)
            gaze_points.append(gaze_point)

        # Create visualization
        image_array = np.array(pil_image)
        fig = visualize_gaze_multi(
            face_boxes, gaze_points, image=image_array, show_plot=False
        )
        return fig, f"Detected {len(faces)} faces."
    except Exception as e:
        return None, f"Error processing image: {str(e)}"
with gr.Blocks(title="Moondream Gaze Detection") as app:
    gr.Markdown("# 🌔 Moondream Gaze Detection")
    gr.Markdown("Upload an image to detect faces and visualize their gaze directions.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            output_text = gr.Textbox(label="Status")
            output_plot = gr.Plot(label="Visualization")

    input_image.change(
        fn=process_image, inputs=[input_image], outputs=[output_plot, output_text]
    )

    gr.Examples(
        examples=["demo1.jpg", "demo2.jpg", "demo3.jpg", "demo4.jpg"],
        inputs=input_image,
    )

if __name__ == "__main__":
    app.launch()