Model Overview

The SegFormer model was proposed in "SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers" by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, and Ping Luo. It pairs a hierarchical Transformer encoder with a lightweight all-MLP decode head and achieves strong results on semantic segmentation benchmarks such as ADE20K and Cityscapes.
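To make the decoder description concrete, here is a minimal NumPy sketch of an all-MLP decode head in the spirit of the paper. It is not the KerasCV implementation. The weights are random, the channel widths [32, 64, 160, 256] and embedding size 256 are assumed B0-like values, batch normalization is omitted, and nearest-neighbour upsampling stands in for the smoother interpolation a real model would use. The head projects each encoder scale to a common width, upsamples everything to 1/4 of the input resolution, concatenates, fuses, and predicts per-pixel class scores.

```python
import numpy as np

rng = np.random.default_rng(0)


def linear(x, weight, bias):
    """Per-pixel linear layer applied over the last (channel) axis."""
    return x @ weight + bias


def upsample_nearest(x, factor):
    """Nearest-neighbour upsampling of an (H, W, C) map by an integer factor."""
    return np.repeat(np.repeat(x, factor, axis=0), factor, axis=1)


def all_mlp_decode_head(features, embed_dim, num_classes):
    """Fuse multi-scale features (strides 4, 8, 16, 32) into class scores.

    Returns an (H/4, W/4, num_classes) array. Weights are random here;
    a trained model learns them.
    """
    target_h = features[0].shape[0]
    unified = []
    for f in features:
        # 1) Project this scale to the shared embedding width.
        w_proj = rng.normal(scale=0.02, size=(f.shape[-1], embed_dim))
        x = linear(f, w_proj, np.zeros(embed_dim))
        # 2) Upsample to the stride-4 resolution.
        unified.append(upsample_nearest(x, target_h // f.shape[0]))
    # 3) Concatenate along channels and fuse (ReLU; batch norm omitted).
    x = np.concatenate(unified, axis=-1)
    w_fuse = rng.normal(scale=0.02, size=(len(features) * embed_dim, embed_dim))
    x = np.maximum(linear(x, w_fuse, np.zeros(embed_dim)), 0.0)
    # 4) Predict per-pixel class scores.
    w_cls = rng.normal(scale=0.02, size=(embed_dim, num_classes))
    return linear(x, w_cls, np.zeros(num_classes))


# Dummy encoder outputs for a 224x224 input (B0-like widths, assumed).
height = width = 224
strides = [4, 8, 16, 32]
channels = [32, 64, 160, 256]
feats = [rng.normal(size=(height // s, width // s, c)) for s, c in zip(strides, channels)]

logits = all_mlp_decode_head(feats, embed_dim=256, num_classes=2)
print(logits.shape)  # (56, 56, 2)
```

Because every scale is reduced to the same width before fusion, the head needs only linear layers, which is what keeps it lightweight compared with convolutional decoders.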

Weights are released under the MIT License. Keras model code is released under the Apache 2.0 License.

Installation

Keras and KerasCV can be installed with:

pip install -U -q keras-cv
pip install -U -q "keras>=3"

JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the Keras Getting Started page.
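Keras 3 picks its backend from the KERAS_BACKEND environment variable when it is first imported, so the choice has to be made before `import keras`. The helper below is a hypothetical convenience (not part of Keras or KerasCV) that validates the name against the three backends mentioned above and sets the variable.

```python
import os

# The three backends mentioned above; Keras 3 reads KERAS_BACKEND at import time.
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_keras_backend(name: str) -> str:
    """Set the backend Keras 3 should use. Call this before `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unsupported backend {name!r}; choose from {SUPPORTED_BACKENDS}")
    os.environ["KERAS_BACKEND"] = name
    return name


select_keras_backend("jax")
# import keras  # a first import after this point uses the JAX backend
```

Setting the variable after Keras has already been imported in the same process has no effect, which is why the call belongs at the very top of a notebook or script.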

Presets

The following model checkpoints are provided by the Keras team. Code examples are available below.

Preset name              Parameters  Description
segformer_b0_imagenet    3.72M       SegFormer model with a pretrained MiTB0 backbone.
segformer_b0             3.72M       SegFormer model with MiTB0 backbone.
segformer_b1             13.68M      SegFormer model with MiTB1 backbone.
segformer_b2             24.73M      SegFormer model with MiTB2 backbone.
segformer_b3             44.60M      SegFormer model with MiTB3 backbone.
segformer_b4             61.37M      SegFormer model with MiTB4 backbone.
segformer_b5             81.97M      SegFormer model with MiTB5 backbone.

Example code

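import keras_cv
import numpy as np

# A dummy batch holding one 224x224 RGB image
images = np.ones(shape=(1, 224, 224, 3))

# Load the SegFormer-B0 preset with a two-class segmentation head
model = keras_cv.models.SegFormer.from_preset(
    "segformer_b0", num_classes=2
)

# Run inference on the dummy batch
model(images)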
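To turn the model's output into a segmentation mask, one common step is a per-pixel argmax over the class axis. The sketch below assumes the output has shape (batch, height, width, num_classes); it uses random stand-in scores so it runs without downloading weights. With a real model, `model.predict(images)` returns a NumPy array that can be passed in directly. The `scores_to_mask` name is hypothetical.

```python
import numpy as np


def scores_to_mask(scores: np.ndarray) -> np.ndarray:
    """Collapse (batch, height, width, num_classes) scores to integer class ids."""
    return np.argmax(scores, axis=-1).astype(np.int32)


# Stand-in for the output on the (1, 224, 224, 3) batch above with num_classes=2.
rng = np.random.default_rng(0)
dummy_scores = rng.random((1, 224, 224, 2))

mask = scores_to_mask(dummy_scores)
print(mask.shape)  # (1, 224, 224)
```

Each entry of `mask` is the index of the highest-scoring class for that pixel, so with two classes the values are 0 or 1.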

Example Usage

Using the class with a backbone:

import keras
import keras_cv
import numpy as np

images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_cv.models.MiTBackbone.from_preset("segformer_b4_cityscapes_1024")
model = keras_cv.models.segmentation.SegFormer(
    num_classes=1, backbone=backbone,
)

# Run inference on the dummy batch
model(images)

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
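The training example uses binary cross-entropy on probabilities (`from_logits=False`). The NumPy sketch below computes the same per-pixel quantity by hand so the numbers are easy to check. It assumes a clipping constant of 1e-7 (Keras' default epsilon) and averages over every pixel, which for these single-channel labels matches the default reduction. It is an illustration, not the Keras source.

```python
import numpy as np

EPSILON = 1e-7  # assumed clipping constant, matching Keras' default epsilon


def binary_crossentropy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean per-pixel binary cross-entropy on probabilities (from_logits=False)."""
    p = np.clip(y_pred, EPSILON, 1.0 - EPSILON)  # avoid log(0)
    per_pixel = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return float(per_pixel.mean())


# Same shapes as the example above: all-zero labels, uninformative predictions.
demo_labels = np.zeros(shape=(1, 96, 96, 1))
demo_probs = np.full(shape=(1, 96, 96, 1), fill_value=0.5)
print(round(binary_crossentropy(demo_labels, demo_probs), 4))  # 0.6931
```

A prediction of 0.5 everywhere costs ln 2 per pixel regardless of the label, which is a useful baseline when reading the loss printed by `model.fit`.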

Example Usage with Hugging Face URI

Loading the backbone from a Hugging Face Hub URI instead of a built-in preset name:

import keras
import keras_cv
import numpy as np

images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_cv.models.MiTBackbone.from_preset("hf://keras/segformer_b4_cityscapes_1024")
model = keras_cv.models.segmentation.SegFormer(
    num_classes=1, backbone=backbone,
)

# Run inference on the dummy batch
model(images)

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
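The only change from the previous example is the preset string: `segformer_b4_cityscapes_1024` becomes `hf://keras/segformer_b4_cityscapes_1024`. A hypothetical helper for building such URIs is sketched below. It assumes, as in this example, that the Hub repository is named after the preset and owned by `keras`; strings that already carry a scheme are passed through unchanged.

```python
def hf_preset_uri(preset: str, owner: str = "keras") -> str:
    """Turn a bare preset name into an hf://<owner>/<preset> URI (hypothetical helper)."""
    if "://" in preset:
        return preset  # already a full URI, leave it alone
    return f"hf://{owner}/{preset}"


print(hf_preset_uri("segformer_b4_cityscapes_1024"))
# hf://keras/segformer_b4_cityscapes_1024
```

The resulting string can be passed to `from_preset` in place of the literal URI used above.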