from typing import Any

import base64
import gc
import logging
import subprocess
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


class EndpointHandler:
    def __init__(self, path=""):
        # Load the image processor for the Swin2SR classical x2 super-resolution checkpoint.
        self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")

        # Tell accelerate which modules must never be split across devices when
        # it builds the device map.
        Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
        Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]

        # First pass: let accelerate infer a device map, then pin the layers that
        # follow the transformer body to the same device as the embeddings so the
        # residual connection in the forward pass does not cross devices.
        model = Swin2SRForImageSuperResolution.from_pretrained(
            "caidas/swin2SR-classical-sr-x2-64", device_map="auto"
        )
        logger.info(model.hf_device_map)
        model.hf_device_map["swin2sr.conv_after_body"] = model.hf_device_map["swin2sr.embeddings"]
        model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]

        # Second pass: reload the model with the adjusted device map.
        self.model = Swin2SRForImageSuperResolution.from_pretrained(
            "caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map
        )
        print(subprocess.run(["nvidia-smi"]))

    def __call__(self, data: Any):
        image = data["inputs"]
        inputs = self.processor(image, return_tensors="pt")
        try:
            with torch.no_grad():
                outputs = self.model(**inputs)

            print(subprocess.run(["nvidia-smi"]))

            # Convert the reconstructed (C, H, W) tensor in [0, 1] to an RGB uint8 array.
            output = outputs.reconstruction.squeeze().float().cpu().clamp_(0, 1).numpy()
            output = np.moveaxis(output, source=0, destination=-1)
            output = (output * 255.0).round().astype(np.uint8)
            img = Image.fromarray(output)

            # Return the upscaled image as a base64-encoded JPEG string.
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue())
            return img_str.decode()
        except Exception as e:
            logger.error(str(e))
            # Free GPU memory after a failed forward pass (e.g. CUDA OOM).
            del inputs
            gc.collect()
            torch.cuda.empty_cache()
            return {"error": str(e)}
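

# --- Usage sketch (illustrative, not part of the deployed handler) ---
# A minimal local smoke test, assuming a sample image at "sample.jpg"
# (hypothetical path). In production the Inference Endpoints runtime
# constructs EndpointHandler and calls it with the decoded request payload
# under the "inputs" key; this block only mimics that call locally.
if __name__ == "__main__":
    handler = EndpointHandler()
    image = Image.open("sample.jpg").convert("RGB")
    result = handler({"inputs": image})
    if isinstance(result, dict):
        print("error:", result["error"])
    else:
        # result is a base64-encoded JPEG of the 2x-upscaled image
        with open("upscaled.jpg", "wb") as f:
            f.write(base64.b64decode(result))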