qubvel-hf HF staff commited on
Commit
243b0b7
·
verified ·
1 Parent(s): 1477e95

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +45 -5
README.md CHANGED
@@ -20,10 +20,53 @@ Table of Contents:
20
  - [Dataset](#dataset)
21
 
22
  ## Load trained model
 
 
 
 
 
 
 
 
 
 
 
23
  ```python
 
 
 
 
24
  import segmentation_models_pytorch as smp
25
 
26
- model = smp.from_pretrained("smp-hub/segformer-b1-1024x1024-city-160k")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  ```
28
 
29
  ## Model init parameters
@@ -40,11 +83,8 @@ model_init_params = {
40
  }
41
  ```
42
 
43
- ## Model metrics
44
- [More Information Needed]
45
-
46
  ## Dataset
47
- Dataset name: [More Information Needed]
48
 
49
  ## More Information
50
  - Library: https://github.com/qubvel/segmentation_models.pytorch
 
20
  - [Dataset](#dataset)
21
 
22
  ## Load trained model
23
+
24
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb)
25
+
26
+ 1. Install requirements.
27
+
28
+ ```bash
29
+ pip install -U segmentation_models_pytorch albumentations
30
+ ```
31
+
32
+ 2. Run inference.
33
+
34
  ```python
35
+ import torch
36
+ import requests
37
+ import numpy as np
38
+ import albumentations as A
39
  import segmentation_models_pytorch as smp
40
 
41
+ from PIL import Image
42
+
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+
45
+ # Load pretrained model and preprocessing function
46
+ checkpoint = "smp-hub/segformer-b1-1024x1024-city-160k"
47
+ model = smp.from_pretrained(checkpoint).eval().to(device)
48
+ preprocessing = A.Compose.from_pretrained(checkpoint)
49
+
50
+ # Load image
51
+ url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
52
+ image = Image.open(requests.get(url, stream=True).raw)
53
+
54
+ # Preprocess image
55
+ np_image = np.array(image)
56
+ normalized_image = preprocessing(image=np_image)["image"]
57
+ input_tensor = torch.as_tensor(normalized_image)
58
+ input_tensor = input_tensor.permute(2, 0, 1).unsqueeze(0) # HWC -> BCHW
59
+ input_tensor = input_tensor.to(device)
60
+
61
+ # Perform inference
62
+ with torch.no_grad():
63
+ output_mask = model(input_tensor)
64
+
65
+ # Postprocess mask
66
+ mask = torch.nn.functional.interpolate(
67
+ output_mask, size=(image.height, image.width), mode="bilinear", align_corners=False
68
+ )
69
+ mask = mask.argmax(1).cpu().numpy() # argmax over predicted classes (channels dim)
70
  ```
71
 
72
  ## Model init parameters
 
83
  }
84
  ```
85
 
 
 
 
86
  ## Dataset
87
+ Dataset name: [Cityscapes](https://paperswithcode.com/dataset/cityscapes)
88
 
89
  ## More Information
90
  - Library: https://github.com/qubvel/segmentation_models.pytorch