File size: 1,071 Bytes
b9cc655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from functools import lru_cache

import torch
import torch.nn as nn
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation

@lru_cache(maxsize=1)
def _load_segmentation_pipeline():
    """Load and cache the clothes-segmentation processor and model.

    Loading is expensive (downloads weights on first use), so it is done
    once per process instead of on every call.
    """
    processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
    model.eval()  # inference mode: disable dropout / freeze batchnorm stats
    return processor, model


def calculate_seg_mask(image):
    """Compute a per-pixel clothing-segmentation mask for an image.

    Runs the ``mattmdjaga/segformer_b2_clothes`` SegFormer model. Class ids
    in the returned mask: 0 Background, 1 Hat, 2 Hair, 3 Sunglasses,
    4 Upper-clothes, 5 Skirt, 6 Pants, 7 Dress, 8 Belt, 9 Left-shoe,
    10 Right-shoe, 11 Face, 12 Left-leg, 13 Right-leg, 14 Left-arm,
    15 Right-arm, 16 Bag, 17 Scarf.

    Parameters
    ----------
    image : str | os.PathLike | file object
        Anything accepted by ``PIL.Image.open`` (a path or a file-like
        object containing image data).

    Returns
    -------
    torch.Tensor
        2-D tensor of shape (height, width), one class id per pixel, on CPU.
    """
    image = Image.open(image).convert("RGB")

    processor, model = _load_segmentation_pipeline()

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()

    # The model emits low-resolution logits; upsample back to the original
    # image resolution. PIL's .size is (width, height) while interpolate
    # expects (height, width) — hence the [::-1] reversal.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    # Per-pixel argmax over the class dimension; drop the batch dimension.
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg