import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import spaces  # noqa: F401  (currently unused in this file)
import torch
from gradio_imageslider import ImageSlider
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Device used for inference. The app is configured to run on CPU (no GPU needed).
DEVICE = "cpu"

# Seconds to wait when downloading a remote image, so a stalled server
# cannot hang a request forever.
_URL_TIMEOUT = 30

# BiRefNet segmentation model, loaded once at import time.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(DEVICE)
birefnet.eval()  # inference only; make the mode explicit

# Preprocessing: fixed 1024x1024 input, ImageNet mean/std normalization.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def load_img(source, output_type="pil"):
    """Load an image and convert it to RGBA.

    Args:
        source: A local file path or an http(s) URL (``str``), a
            ``PIL.Image.Image``, or an array accepted by
            ``PIL.Image.fromarray`` (e.g. the numpy array ``gr.Image``
            produces with its default ``type="numpy"``).
        output_type: ``"pil"`` to return the image object, or ``"filepath"``
            to save it as a PNG in a unique temporary file and return the path.

    Returns:
        The RGBA ``PIL.Image.Image``, or the path of the saved PNG.

    Raises:
        ValueError: If ``output_type`` is unsupported, or if the image cannot
            be downloaded, opened or decoded (the original error is chained).
    """
    if output_type not in ("pil", "filepath"):
        raise ValueError(f"Unsupported output_type: {output_type!r}")

    try:
        if isinstance(source, str):
            if source.startswith(("http://", "https://")):
                response = requests.get(source, timeout=_URL_TIMEOUT)
                # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as raw:
                    img = raw.convert("RGBA")
            else:
                # The context manager closes the file handle Image.open keeps.
                with Image.open(source) as raw:
                    img = raw.convert("RGBA")
        elif isinstance(source, Image.Image):
            img = source.convert("RGBA")
        else:
            # e.g. a numpy array coming from a gr.Image component.
            img = Image.fromarray(source).convert("RGBA")

        if output_type == "pil":
            return img

        # Unique file per call so concurrent requests never overwrite each other.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            img.save(tmp, format="PNG")
            return tmp.name
    except Exception as e:
        raise ValueError(f"이미지를 로드하는 중 오류가 발생했습니다: {e}") from e


def process(image):
    """Remove the background of *image* with BiRefNet.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        An RGBA ``PIL.Image.Image`` with the same size as *image*. Its RGB
        channels are the input converted to RGB, and its alpha channel is the
        model's sigmoid-activated prediction resized to the input size.
    """
    image_size = image.size

    # The normalization expects exactly 3 channels, so convert every non-RGB
    # mode (RGBA, L, P, CMYK, ...), not only RGBA.
    if image.mode != "RGB":
        image = image.convert("RGB")

    input_images = transform_image(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # The model returns a sequence of predictions; the last one is used
        # (presumably the final-stage output) and squashed to [0, 1].
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    # The input was resized to 1024x1024; scale the mask back to the original size.
    mask = transforms.ToPILImage()(pred).resize(image_size)

    # Copy so a caller's RGB image is not mutated by putalpha.
    result_image = image.copy()
    result_image.putalpha(mask)
    return result_image


def fn(image):
    """Load *image* and remove its background.

    Args:
        image: Anything ``load_img`` accepts (path, URL, PIL image, array).

    Returns:
        A tuple ``(processed, original, processed)``. The first and third
        items are the same RGBA result image, and ``original`` is an RGBA copy
        of the input.
    """
    im = load_img(image, output_type="pil")
    origin = im.copy()
    processed_image = process(im)
    return processed_image, origin, processed_image


def process_file(f):
    """Remove the background of the image file at path *f*.

    The result is saved next to the input, with the extension replaced by
    ``.png``. If the input is already a ``.png``, it is overwritten.

    Args:
        f: Path to the input image file.

    Returns:
        The path of the written PNG file.
    """
    # splitext only touches the final component, so dotted directory names
    # and extension-less files are handled correctly (unlike rsplit(".")).
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    transparent = process(im)  # process() handles the RGBA -> RGB conversion
    transparent.save(name_path)
    return name_path
slider1 = ImageSlider(label="Processed Image", type="pil")
slider2 = ImageSlider(label="Processed Image from URL", type="pil")
# type="pil": the default ("numpy") would pass an ndarray to the handler,
# which load_img is not guaranteed to accept.
image_upload = gr.Image(label="Upload an image", type="pil")
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")

# Download components for the processed PNG.
download1 = gr.File(label="Download PNG", interactive=False)
download2 = gr.File(label="Download PNG from URL", interactive=False)

# Example inputs. Only a URL example is provided, because local example image
# files may not exist and would cause an error.
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"


def _save_png(image, prefix):
    """Save *image* as a PNG in a new, uniquely named temp file and return its path."""
    import tempfile

    # A unique name per request keeps concurrent users from overwriting each
    # other's downloads (a fixed filename in the cwd would race).
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
        return tmp.name


def _run_with_download(source, prefix):
    """Run background removal on *source* and build the three UI outputs.

    Returns:
        ``((processed, original), original, png_path)``. ImageSlider needs a
        pair of images to compare, not a single image.
    """
    processed_image, origin, png_image = fn(source)
    png_path = _save_png(png_image, prefix)
    return (processed_image, origin), origin, png_path


def fn_with_download(image):
    """Handler for the upload tab: slider pair, original image, PNG path."""
    return _run_with_download(image, "processed_image_")


def fn_url_with_download(url):
    """Handler for the URL tab: slider pair, original image, PNG path."""
    # fn() loads the URL through load_img, so no separate download is needed.
    return _run_with_download(url, "processed_image_url_")


tab1 = gr.Interface(
    fn_with_download,
    inputs=image_upload,
    outputs=[slider1, gr.Image(label="Original Image"), download1],
    examples=[],  # no bundled example images, so an empty list
    api_name="image",
)

tab2 = gr.Interface(
    fn_url_with_download,
    inputs=url_input,
    outputs=[slider2, gr.Image(label="Original Image from URL"), download2],
    examples=[url_example],
    api_name="text",
)

tab3 = gr.Interface(
    process_file,
    inputs=image_file_upload,
    outputs=output_file,
    examples=[],
    api_name="png",
)

demo = gr.TabbedInterface(
    [tab1, tab2, tab3],
    ["Image Upload", "URL Input", "File Output"],
    title="Background Removal Tool",
)

if __name__ == "__main__":
    # Host/port can be configured here if needed.
    demo.launch(show_error=True)