
mx.expand_dims in code example does not work

#1
by chrisoffner3d - opened

I'm new to both MLX and Hugging Face, so apologies if I'm using this wrong, but the mx.expand_dims(x, 0) call in the code example

from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

transform = ImageNetTransform(train=False, img_size=518)
x = transform(read_rgb("cat.png"))
x = mx.expand_dims(x, 0)

model = create_model("vit_large_patch14_518.dinov2")
model.eval()

logits, attn_masks = model(x, attn_masks=True)

on the model card gives the type error

TypeError: expand_dims(): incompatible function arguments. The following argument types are supported:
    1. expand_dims(a: array, /, axis: Union[int, Sequence[int]], *, stream: Union[None, Stream, Device] = None) -> array

Invoked with types: ndarray, int

The transform call returns a numpy.ndarray, whereas mx.expand_dims expects an mlx.core.array.
Assigning x = mx.array(x) before the mx.expand_dims call fixes this.
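
For anyone hitting the same error, here is a minimal sketch of that fix in isolation. It uses a zero-filled NumPy array as a stand-in for the output of transform(read_rgb("cat.png")); the 518x518x3 channels-last shape is my assumption based on img_size=518, and to_batched_mlx is a hypothetical helper name, not part of mlx-image. The import is guarded only so the snippet does not crash on machines without MLX installed.

```python
import numpy as np

try:
    import mlx.core as mx  # only available where MLX is installed
except ImportError:
    mx = None


def to_batched_mlx(x):
    """Convert a NumPy image to an mlx.core.array and prepend a batch axis.

    Hypothetical helper for illustration; not part of mlx-image.
    """
    x = mx.array(x)               # numpy.ndarray -> mlx.core.array
    return mx.expand_dims(x, 0)   # now accepted, since x is an mlx array


# Stand-in for transform(read_rgb("cat.png")); the shape is an assumption.
image = np.zeros((518, 518, 3), dtype=np.float32)

if mx is not None:
    batch = to_batched_mlx(image)
    print(batch.shape)
```

In the model card example, this amounts to adding import mlx.core as mx at the top and inserting x = mx.array(x) between the transform call and the mx.expand_dims call.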

UjjwalK changed discussion status to closed
