mx.expand_dims in code example does not work
#1 opened by chrisoffner3d
I'm new to both MLX and Hugging Face, so apologies if I'm using this wrong, but the `mx.expand_dims(x, 0)` call in the code example on the model card fails:

```python
import mlx.core as mx  # added here so the snippet is self-contained
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

transform = ImageNetTransform(train=False, img_size=518)
x = transform(read_rgb("cat.png"))  # x is a numpy.ndarray at this point
x = mx.expand_dims(x, 0)            # raises the TypeError below

model = create_model("vit_large_patch14_518.dinov2")
model.eval()
logits, attn_masks = model(x, attn_masks=True)
```

It raises this type error:
```
TypeError: expand_dims(): incompatible function arguments. The following argument types are supported:
    1. expand_dims(a: array, /, axis: Union[int, Sequence[int]], *, stream: Union[None, Stream, Device] = None) -> array

Invoked with types: ndarray, int
```
The `transform` call returns a `numpy.ndarray`, whereas `mx.expand_dims` expects an `mlx.core.array`.
Adding `x = mx.array(x)` before the `mx.expand_dims` call fixes this.
UjjwalK changed discussion status to closed