import os
from functools import partial
from typing import Dict, List, Optional, Union

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import numpy as np
import PIL.Image
import torch
from datasets import load_dataset, DatasetDict, Image
from torch.utils.data import DataLoader

from transformers import ViTImageProcessor
from transformers.image_processing_utils import get_size_dict
from transformers.image_utils import PILImageResampling, ChannelDimension

|
class ResnetImageProcessor(ViTImageProcessor):
    """
    >>> # tfs = A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224)])
    >>> # If tfs=tfs is passed as a constructor argument, calling save_pretrained raises an error.
    >>> # Local usage
    >>> mean = [0.485, 0.456, 0.406]; std = [0.229, 0.224, 0.225]
    >>> image_processor = ResnetImageProcessor(size=(224, 224), image_mean=mean, image_std=std)
    >>> image_processor.save_pretrained("custom-resnet")
    >>> image_processor = ResnetImageProcessor.from_pretrained("custom-resnet")

    >>> # push_to_hub
    >>> # Log in to the Hub
    >>> from huggingface_hub import notebook_login; notebook_login()
    >>> # or: huggingface-cli login

    >>> ResnetImageProcessor.register_for_auto_class()
    >>> mean = [0.485, 0.456, 0.406]; std = [0.229, 0.224, 0.225]
    >>> image_processor = ResnetImageProcessor(size=(224, 224), image_mean=mean, image_std=std)
    >>> image_processor.save_pretrained("custom-resnet")
    >>> # image_processor = ResnetImageProcessor.from_pretrained("custom-resnet")
    >>> # Before running push_to_hub, change "image_processor_type" in custom-resnet/preprocessor_config.json
    >>> # to "ViTImageProcessor": ResnetImageProcessor is not registered with AutoImageProcessor by default,
    >>> # so loading it through AutoImageProcessor would otherwise raise an error.
    >>> image_processor.push_to_hub('custom-resnet')

    >>> # Load from the Hugging Face Hub
    >>> from transformers import AutoImageProcessor
    >>> AutoImageProcessor.register(config_class='wucng/custom-resnet/config.json', image_processor_class=ResnetImageProcessor)
    >>> image_processor = AutoImageProcessor.from_pretrained('wucng/custom-resnet', trust_remote_code=True)
    """

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        size = get_size_dict(size)
        height, width = size["height"], size["width"]

        # Use a caller-supplied albumentations pipeline if one is passed via the `tfs` kwarg;
        # otherwise resize with the standard 256/224 ratio and center-crop to the target size.
        tfs = kwargs.get('tfs', None)
        if tfs is None:
            ratio = 256 / 224
            tfs = A.Compose([A.Resize(int(ratio * height), int(ratio * width)), A.CenterCrop(height, width)])
        return tfs(image=image)['image']
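

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. "demo.jpg" is a placeholder
    # path; any RGB image works. Shapes in the comments assume a 224x224 target size.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    image_processor = ResnetImageProcessor(size=(224, 224), image_mean=mean, image_std=std)

    img = np.array(PIL.Image.open("demo.jpg").convert("RGB"))

    # resize() can be called directly with a custom albumentations pipeline via the tfs kwarg.
    resized = image_processor.resize(
        img,
        size={"height": 224, "width": 224},
        tfs=A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224)]),
    )
    print(resized.shape)  # (224, 224, 3)

    # The full pipeline (resize + rescale + normalize) runs through __call__.
    inputs = image_processor(img, return_tensors="pt")
    print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])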