nikunjkdtechnoland commited on
Commit
063372b
·
1 Parent(s): 526f250

init commit some files

Browse files
Files changed (40) hide show
  1. .gitignore +110 -0
  2. app.py +14 -0
  3. data/__init__.py +0 -0
  4. iopaint/__init__.py +23 -0
  5. iopaint/__main__.py +4 -0
  6. iopaint/api.py +396 -0
  7. iopaint/batch_processing.py +127 -0
  8. iopaint/benchmark.py +109 -0
  9. iopaint/file_manager/__init__.py +1 -0
  10. iopaint/model/__init__.py +37 -0
  11. iopaint/model/anytext/__init__.py +0 -0
  12. iopaint/model/anytext/anytext_model.py +73 -0
  13. iopaint/model/anytext/anytext_pipeline.py +403 -0
  14. iopaint/model/anytext/anytext_sd15.yaml +99 -0
  15. iopaint/model/anytext/cldm/__init__.py +0 -0
  16. iopaint/model/anytext/ldm/__init__.py +0 -0
  17. iopaint/model/anytext/ldm/models/__init__.py +0 -0
  18. iopaint/model/anytext/ldm/models/autoencoder.py +218 -0
  19. iopaint/model/anytext/ldm/models/diffusion/__init__.py +0 -0
  20. iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py +1 -0
  21. iopaint/model/anytext/ldm/modules/__init__.py +0 -0
  22. iopaint/model/anytext/ldm/modules/attention.py +360 -0
  23. iopaint/model/anytext/ldm/modules/diffusionmodules/__init__.py +0 -0
  24. iopaint/model/anytext/ldm/modules/distributions/__init__.py +0 -0
  25. iopaint/model/anytext/ldm/modules/encoders/__init__.py +0 -0
  26. iopaint/model/anytext/ocr_recog/__init__.py +0 -0
  27. iopaint/model/base.py +418 -0
  28. iopaint/model/helper/__init__.py +0 -0
  29. iopaint/model/original_sd_configs/__init__.py +19 -0
  30. iopaint/model/power_paint/__init__.py +0 -0
  31. iopaint/plugins/__init__.py +74 -0
  32. iopaint/plugins/anime_seg.py +462 -0
  33. iopaint/plugins/base_plugin.py +30 -0
  34. iopaint/plugins/segment_anything/__init__.py +14 -0
  35. iopaint/plugins/segment_anything/modeling/__init__.py +11 -0
  36. iopaint/plugins/segment_anything/utils/__init__.py +5 -0
  37. iopaint/tests/.gitignore +2 -0
  38. iopaint/tests/__init__.py +0 -0
  39. model/__init__.py +0 -0
  40. utils/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Project ###
2
+ checkpoints/
3
+ pretrained-model/yolov8m-seg.pt
4
+
5
+
6
+ ### Python ###
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # pyenv
82
+ .python-version
83
+
84
+ # celery beat schedule file
85
+ celerybeat-schedule
86
+
87
+ # SageMath parsed files
88
+ *.sage.py
89
+
90
+ # Environments
91
+ .env
92
+ .venv
93
+ env/
94
+ venv/
95
+ ENV/
96
+ env.bak/
97
+ venv.bak/
98
+
99
+ # Spyder project settings
100
+ .spyderproject
101
+ .spyproject
102
+
103
+ # Rope project settings
104
+ .ropeproject
105
+
106
+ # mkdocs documentation
107
+ /site
108
+
109
+ # mypy
110
+ .mypy_cache/
app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from only_gradio_server import process_images
3
+
4
+ # Create Gradio interface
5
+ iface = gr.Interface(fn=process_images,
6
+ inputs=[gr.Image(type='filepath', label='Input Image 1'),
7
+ gr.Image(type='filepath', label='Input Image 2', image_mode="RGBA"),
8
+ gr.Textbox(label='Replace Object Name')],
9
+ outputs='image',
10
+ title="Image Processing",
11
+ description="Object to Object Replacement")
12
+
13
+ # Launch Gradio interface
14
+ iface.launch()
data/__init__.py ADDED
File without changes
iopaint/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
+ # https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068
5
+ os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
6
+ os.environ["LRU_CACHE_CAPACITY"] = "1"
7
+ # prevent CPU memory leak when run model on GPU
8
+ # https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431
9
+ # https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633
10
+ os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1"
11
+
12
+
13
+ import warnings
14
+
15
+ warnings.simplefilter("ignore", UserWarning)
16
+
17
+
18
+ def entry_point():
19
+ # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
20
+ # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
21
+ from iopaint.cli import typer_app
22
+
23
+ typer_app()
iopaint/__main__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from iopaint import entry_point
2
+
3
+ if __name__ == "__main__":
4
+ entry_point()
iopaint/api.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import threading
4
+ import time
5
+ import traceback
6
+ from pathlib import Path
7
+ from typing import Optional, Dict, List
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import socketio
12
+ import torch
13
+
14
+ try:
15
+ torch._C._jit_override_can_fuse_on_cpu(False)
16
+ torch._C._jit_override_can_fuse_on_gpu(False)
17
+ torch._C._jit_set_texpr_fuser_enabled(False)
18
+ torch._C._jit_set_nvfuser_enabled(False)
19
+ except:
20
+ pass
21
+
22
+
23
+ import uvicorn
24
+ from PIL import Image
25
+ from fastapi import APIRouter, FastAPI, Request, UploadFile
26
+ from fastapi.encoders import jsonable_encoder
27
+ from fastapi.exceptions import HTTPException
28
+ from fastapi.middleware.cors import CORSMiddleware
29
+ from fastapi.responses import JSONResponse, FileResponse, Response
30
+ from fastapi.staticfiles import StaticFiles
31
+ from loguru import logger
32
+ from socketio import AsyncServer
33
+
34
+ from iopaint.file_manager import FileManager
35
+ from iopaint.helper import (
36
+ load_img,
37
+ decode_base64_to_image,
38
+ pil_to_bytes,
39
+ numpy_to_bytes,
40
+ concat_alpha_channel,
41
+ gen_frontend_mask,
42
+ adjust_mask,
43
+ )
44
+ from iopaint.model.utils import torch_gc
45
+ from iopaint.model_manager import ModelManager
46
+ from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
47
+ from iopaint.plugins.base_plugin import BasePlugin
48
+ from iopaint.plugins.remove_bg import RemoveBG
49
+ from iopaint.schema import (
50
+ GenInfoResponse,
51
+ ApiConfig,
52
+ ServerConfigResponse,
53
+ SwitchModelRequest,
54
+ InpaintRequest,
55
+ RunPluginRequest,
56
+ SDSampler,
57
+ PluginInfo,
58
+ AdjustMaskRequest,
59
+ RemoveBGModel,
60
+ SwitchPluginModelRequest,
61
+ ModelInfo,
62
+ InteractiveSegModel,
63
+ RealESRGANModel,
64
+ )
65
+
66
+ CURRENT_DIR = Path(__file__).parent.absolute().resolve()
67
+ WEB_APP_DIR = CURRENT_DIR / "web_app"
68
+
69
+
70
+ def api_middleware(app: FastAPI):
71
+ rich_available = False
72
+ try:
73
+ if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
74
+ import anyio # importing just so it can be placed on silent list
75
+ import starlette # importing just so it can be placed on silent list
76
+ from rich.console import Console
77
+
78
+ console = Console()
79
+ rich_available = True
80
+ except Exception:
81
+ pass
82
+
83
+ def handle_exception(request: Request, e: Exception):
84
+ err = {
85
+ "error": type(e).__name__,
86
+ "detail": vars(e).get("detail", ""),
87
+ "body": vars(e).get("body", ""),
88
+ "errors": str(e),
89
+ }
90
+ if not isinstance(
91
+ e, HTTPException
92
+ ): # do not print backtrace on known httpexceptions
93
+ message = f"API error: {request.method}: {request.url} {err}"
94
+ if rich_available:
95
+ print(message)
96
+ console.print_exception(
97
+ show_locals=True,
98
+ max_frames=2,
99
+ extra_lines=1,
100
+ suppress=[anyio, starlette],
101
+ word_wrap=False,
102
+ width=min([console.width, 200]),
103
+ )
104
+ else:
105
+ traceback.print_exc()
106
+ return JSONResponse(
107
+ status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
108
+ )
109
+
110
+ @app.middleware("http")
111
+ async def exception_handling(request: Request, call_next):
112
+ try:
113
+ return await call_next(request)
114
+ except Exception as e:
115
+ return handle_exception(request, e)
116
+
117
+ @app.exception_handler(Exception)
118
+ async def fastapi_exception_handler(request: Request, e: Exception):
119
+ return handle_exception(request, e)
120
+
121
+ @app.exception_handler(HTTPException)
122
+ async def http_exception_handler(request: Request, e: HTTPException):
123
+ return handle_exception(request, e)
124
+
125
+ cors_options = {
126
+ "allow_methods": ["*"],
127
+ "allow_headers": ["*"],
128
+ "allow_origins": ["*"],
129
+ "allow_credentials": True,
130
+ }
131
+ app.add_middleware(CORSMiddleware, **cors_options)
132
+
133
+
134
+ global_sio: AsyncServer = None
135
+
136
+
137
+ def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
138
+ # self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
139
+ # logger.info(f"diffusion callback: step={step}, timestep={timestep}")
140
+
141
+ # We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
142
+ # but for now let's just start a separate event loop. It shouldn't make a difference for single person use
143
+ asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
144
+ return {}
145
+
146
+
147
+ class Api:
148
+ def __init__(self, app: FastAPI, config: ApiConfig):
149
+ self.app = app
150
+ self.config = config
151
+ self.router = APIRouter()
152
+ self.queue_lock = threading.Lock()
153
+ api_middleware(self.app)
154
+
155
+ self.file_manager = self._build_file_manager()
156
+ self.plugins = self._build_plugins()
157
+ self.model_manager = self._build_model_manager()
158
+
159
+ # fmt: off
160
+ self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
161
+ self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
162
+ self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
163
+ self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
164
+ self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
165
+ self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
166
+ self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
167
+ self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
168
+ self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
169
+ self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
170
+ self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
171
+ self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
172
+ self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
173
+ # fmt: on
174
+
175
+ global global_sio
176
+ self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
177
+ self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
178
+ self.app.mount("/ws", self.combined_asgi_app)
179
+ global_sio = self.sio
180
+
181
+ def add_api_route(self, path: str, endpoint, **kwargs):
182
+ return self.app.add_api_route(path, endpoint, **kwargs)
183
+
184
+ def api_save_image(self, file: UploadFile):
185
+ filename = file.filename
186
+ origin_image_bytes = file.file.read()
187
+ with open(self.config.output_dir / filename, "wb") as fw:
188
+ fw.write(origin_image_bytes)
189
+
190
+ def api_current_model(self) -> ModelInfo:
191
+ return self.model_manager.current_model
192
+
193
+ def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
194
+ if req.name == self.model_manager.name:
195
+ return self.model_manager.current_model
196
+ self.model_manager.switch(req.name)
197
+ return self.model_manager.current_model
198
+
199
+ def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
200
+ if req.plugin_name in self.plugins:
201
+ self.plugins[req.plugin_name].switch_model(req.model_name)
202
+ if req.plugin_name == RemoveBG.name:
203
+ self.config.remove_bg_model = req.model_name
204
+ if req.plugin_name == RealESRGANUpscaler.name:
205
+ self.config.realesrgan_model = req.model_name
206
+ if req.plugin_name == InteractiveSeg.name:
207
+ self.config.interactive_seg_model = req.model_name
208
+ torch_gc()
209
+
210
+ def api_server_config(self) -> ServerConfigResponse:
211
+ plugins = []
212
+ for it in self.plugins.values():
213
+ plugins.append(
214
+ PluginInfo(
215
+ name=it.name,
216
+ support_gen_image=it.support_gen_image,
217
+ support_gen_mask=it.support_gen_mask,
218
+ )
219
+ )
220
+
221
+ return ServerConfigResponse(
222
+ plugins=plugins,
223
+ modelInfos=self.model_manager.scan_models(),
224
+ removeBGModel=self.config.remove_bg_model,
225
+ removeBGModels=RemoveBGModel.values(),
226
+ realesrganModel=self.config.realesrgan_model,
227
+ realesrganModels=RealESRGANModel.values(),
228
+ interactiveSegModel=self.config.interactive_seg_model,
229
+ interactiveSegModels=InteractiveSegModel.values(),
230
+ enableFileManager=self.file_manager is not None,
231
+ enableAutoSaving=self.config.output_dir is not None,
232
+ enableControlnet=self.model_manager.enable_controlnet,
233
+ controlnetMethod=self.model_manager.controlnet_method,
234
+ disableModelSwitch=False,
235
+ isDesktop=False,
236
+ samplers=self.api_samplers(),
237
+ )
238
+
239
+ def api_input_image(self) -> FileResponse:
240
+ if self.config.input and self.config.input.is_file():
241
+ return FileResponse(self.config.input)
242
+ raise HTTPException(status_code=404, detail="Input image not found")
243
+
244
+ def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
245
+ _, _, info = load_img(file.file.read(), return_info=True)
246
+ parts = info.get("parameters", "").split("Negative prompt: ")
247
+ prompt = parts[0].strip()
248
+ negative_prompt = ""
249
+ if len(parts) > 1:
250
+ negative_prompt = parts[1].split("\n")[0].strip()
251
+ return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
252
+
253
+ def api_inpaint(self, req: InpaintRequest):
254
+ image, alpha_channel, infos = decode_base64_to_image(req.image)
255
+ mask, _, _ = decode_base64_to_image(req.mask, gray=True)
256
+
257
+ mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
258
+ if image.shape[:2] != mask.shape[:2]:
259
+ raise HTTPException(
260
+ 400,
261
+ detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
262
+ )
263
+
264
+ if req.paint_by_example_example_image:
265
+ paint_by_example_image, _, _ = decode_base64_to_image(
266
+ req.paint_by_example_example_image
267
+ )
268
+
269
+ start = time.time()
270
+ rgb_np_img = self.model_manager(image, mask, req)
271
+ logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
272
+ torch_gc()
273
+
274
+ rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
275
+ rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
276
+
277
+ ext = "png"
278
+ res_img_bytes = pil_to_bytes(
279
+ Image.fromarray(rgb_res),
280
+ ext=ext,
281
+ quality=self.config.quality,
282
+ infos=infos,
283
+ )
284
+
285
+ asyncio.run(self.sio.emit("diffusion_finish"))
286
+
287
+ return Response(
288
+ content=res_img_bytes,
289
+ media_type=f"image/{ext}",
290
+ headers={"X-Seed": str(req.sd_seed)},
291
+ )
292
+
293
+ def api_run_plugin_gen_image(self, req: RunPluginRequest):
294
+ ext = "png"
295
+ if req.name not in self.plugins:
296
+ raise HTTPException(status_code=422, detail="Plugin not found")
297
+ if not self.plugins[req.name].support_gen_image:
298
+ raise HTTPException(
299
+ status_code=422, detail="Plugin does not support output image"
300
+ )
301
+ rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
302
+ bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
303
+ torch_gc()
304
+
305
+ if bgr_or_rgba_np_img.shape[2] == 4:
306
+ rgba_np_img = bgr_or_rgba_np_img
307
+ else:
308
+ rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
309
+ rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
310
+
311
+ return Response(
312
+ content=pil_to_bytes(
313
+ Image.fromarray(rgba_np_img),
314
+ ext=ext,
315
+ quality=self.config.quality,
316
+ infos=infos,
317
+ ),
318
+ media_type=f"image/{ext}",
319
+ )
320
+
321
+ def api_run_plugin_gen_mask(self, req: RunPluginRequest):
322
+ if req.name not in self.plugins:
323
+ raise HTTPException(status_code=422, detail="Plugin not found")
324
+ if not self.plugins[req.name].support_gen_mask:
325
+ raise HTTPException(
326
+ status_code=422, detail="Plugin does not support output image"
327
+ )
328
+ rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
329
+ bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
330
+ torch_gc()
331
+ res_mask = gen_frontend_mask(bgr_or_gray_mask)
332
+ return Response(
333
+ content=numpy_to_bytes(res_mask, "png"),
334
+ media_type="image/png",
335
+ )
336
+
337
+ def api_samplers(self) -> List[str]:
338
+ return [member.value for member in SDSampler.__members__.values()]
339
+
340
+ def api_adjust_mask(self, req: AdjustMaskRequest):
341
+ mask, _, _ = decode_base64_to_image(req.mask, gray=True)
342
+ mask = adjust_mask(mask, req.kernel_size, req.operate)
343
+ return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
344
+
345
+ def launch(self):
346
+ self.app.include_router(self.router)
347
+ uvicorn.run(
348
+ self.combined_asgi_app,
349
+ host=self.config.host,
350
+ port=self.config.port,
351
+ timeout_keep_alive=999999999,
352
+ )
353
+
354
+ def _build_file_manager(self) -> Optional[FileManager]:
355
+ if self.config.input and self.config.input.is_dir():
356
+ logger.info(
357
+ f"Input is directory, initialize file manager {self.config.input}"
358
+ )
359
+
360
+ return FileManager(
361
+ app=self.app,
362
+ input_dir=self.config.input,
363
+ output_dir=self.config.output_dir,
364
+ )
365
+ return None
366
+
367
+ def _build_plugins(self) -> Dict[str, BasePlugin]:
368
+ return build_plugins(
369
+ self.config.enable_interactive_seg,
370
+ self.config.interactive_seg_model,
371
+ self.config.interactive_seg_device,
372
+ self.config.enable_remove_bg,
373
+ self.config.remove_bg_model,
374
+ self.config.enable_anime_seg,
375
+ self.config.enable_realesrgan,
376
+ self.config.realesrgan_device,
377
+ self.config.realesrgan_model,
378
+ self.config.enable_gfpgan,
379
+ self.config.gfpgan_device,
380
+ self.config.enable_restoreformer,
381
+ self.config.restoreformer_device,
382
+ self.config.no_half,
383
+ )
384
+
385
+ def _build_model_manager(self):
386
+ return ModelManager(
387
+ name=self.config.model,
388
+ device=torch.device(self.config.device),
389
+ no_half=self.config.no_half,
390
+ low_mem=self.config.low_mem,
391
+ disable_nsfw=self.config.disable_nsfw_checker,
392
+ sd_cpu_textencoder=self.config.cpu_textencoder,
393
+ local_files_only=self.config.local_files_only,
394
+ cpu_offload=self.config.cpu_offload,
395
+ callback=diffuser_callback,
396
+ )
iopaint/batch_processing.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Dict, Optional
4
+
5
+ import cv2
6
+ import psutil
7
+ from PIL import Image
8
+ from loguru import logger
9
+ from rich.console import Console
10
+ from rich.progress import (
11
+ Progress,
12
+ SpinnerColumn,
13
+ TimeElapsedColumn,
14
+ MofNCompleteColumn,
15
+ TextColumn,
16
+ BarColumn,
17
+ TaskProgressColumn,
18
+ )
19
+
20
+ from iopaint.helper import pil_to_bytes
21
+ from iopaint.model.utils import torch_gc
22
+ from iopaint.model_manager import ModelManager
23
+ from iopaint.schema import InpaintRequest
24
+
25
+
26
+ def glob_images(path: Path) -> Dict[str, Path]:
27
+ # png/jpg/jpeg
28
+ if path.is_file():
29
+ return {path.stem: path}
30
+ elif path.is_dir():
31
+ res = {}
32
+ for it in path.glob("*.*"):
33
+ if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
34
+ res[it.stem] = it
35
+ return res
36
+
37
+
38
+ def batch_inpaint(
39
+ model: str,
40
+ device,
41
+ image: Path,
42
+ mask: Path,
43
+ output: Path,
44
+ config: Optional[Path] = None,
45
+ concat: bool = False,
46
+ ):
47
+ if image.is_dir() and output.is_file():
48
+ logger.error(
49
+ f"invalid --output: when image is a directory, output should be a directory"
50
+ )
51
+ exit(-1)
52
+ output.mkdir(parents=True, exist_ok=True)
53
+
54
+ image_paths = glob_images(image)
55
+ mask_paths = glob_images(mask)
56
+ if len(image_paths) == 0:
57
+ logger.error(f"invalid --image: empty image folder")
58
+ exit(-1)
59
+ if len(mask_paths) == 0:
60
+ logger.error(f"invalid --mask: empty mask folder")
61
+ exit(-1)
62
+
63
+ if config is None:
64
+ inpaint_request = InpaintRequest()
65
+ logger.info(f"Using default config: {inpaint_request}")
66
+ else:
67
+ with open(config, "r", encoding="utf-8") as f:
68
+ inpaint_request = InpaintRequest(**json.load(f))
69
+
70
+ model_manager = ModelManager(name=model, device=device)
71
+ first_mask = list(mask_paths.values())[0]
72
+
73
+ console = Console()
74
+
75
+ with Progress(
76
+ SpinnerColumn(),
77
+ TextColumn("[progress.description]{task.description}"),
78
+ BarColumn(),
79
+ TaskProgressColumn(),
80
+ MofNCompleteColumn(),
81
+ TimeElapsedColumn(),
82
+ console=console,
83
+ transient=False,
84
+ ) as progress:
85
+ task = progress.add_task("Batch processing...", total=len(image_paths))
86
+ for stem, image_p in image_paths.items():
87
+ if stem not in mask_paths and mask.is_dir():
88
+ progress.log(f"mask for {image_p} not found")
89
+ progress.update(task, advance=1)
90
+ continue
91
+ mask_p = mask_paths.get(stem, first_mask)
92
+
93
+ infos = Image.open(image_p).info
94
+
95
+ img = cv2.imread(str(image_p))
96
+ img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
97
+ mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
98
+ if mask_img.shape[:2] != img.shape[:2]:
99
+ progress.log(
100
+ f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
101
+ )
102
+ mask_img = cv2.resize(
103
+ mask_img,
104
+ (img.shape[1], img.shape[0]),
105
+ interpolation=cv2.INTER_NEAREST,
106
+ )
107
+ mask_img[mask_img >= 127] = 255
108
+ mask_img[mask_img < 127] = 0
109
+
110
+ # bgr
111
+ inpaint_result = model_manager(img, mask_img, inpaint_request)
112
+ inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
113
+ if concat:
114
+ mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
115
+ inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
116
+
117
+ img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
118
+ save_p = output / f"{stem}.png"
119
+ with open(save_p, "wb") as fw:
120
+ fw.write(img_bytes)
121
+
122
+ progress.update(task, advance=1)
123
+ torch_gc()
124
+ # pid = psutil.Process().pid
125
+ # memory_info = psutil.Process(pid).memory_info()
126
+ # memory_in_mb = memory_info.rss / (1024 * 1024)
127
+ # print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB")
iopaint/benchmark.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import argparse
4
+ import os
5
+ import time
6
+
7
+ import numpy as np
8
+ import nvidia_smi
9
+ import psutil
10
+ import torch
11
+
12
+ from iopaint.model_manager import ModelManager
13
+ from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
14
+
15
+ try:
16
+ torch._C._jit_override_can_fuse_on_cpu(False)
17
+ torch._C._jit_override_can_fuse_on_gpu(False)
18
+ torch._C._jit_set_texpr_fuser_enabled(False)
19
+ torch._C._jit_set_nvfuser_enabled(False)
20
+ except:
21
+ pass
22
+
23
+ NUM_THREADS = str(4)
24
+
25
+ os.environ["OMP_NUM_THREADS"] = NUM_THREADS
26
+ os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
27
+ os.environ["MKL_NUM_THREADS"] = NUM_THREADS
28
+ os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
29
+ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
30
+ if os.environ.get("CACHE_DIR"):
31
+ os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
32
+
33
+
34
+ def run_model(model, size):
35
+ # RGB
36
+ image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
37
+ mask = np.random.randint(0, 255, size).astype(np.uint8)
38
+
39
+ config = InpaintRequest(
40
+ ldm_steps=2,
41
+ hd_strategy=HDStrategy.ORIGINAL,
42
+ hd_strategy_crop_margin=128,
43
+ hd_strategy_crop_trigger_size=128,
44
+ hd_strategy_resize_limit=128,
45
+ prompt="a fox is sitting on a bench",
46
+ sd_steps=5,
47
+ sd_sampler=SDSampler.ddim,
48
+ )
49
+ model(image, mask, config)
50
+
51
+
52
+ def benchmark(model, times: int, empty_cache: bool):
53
+ sizes = [(512, 512)]
54
+
55
+ nvidia_smi.nvmlInit()
56
+ device_id = 0
57
+ handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
58
+
59
+ def format(metrics):
60
+ return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
61
+
62
+ process = psutil.Process(os.getpid())
63
+ # 每个 size 给出显存和内存占用的指标
64
+ for size in sizes:
65
+ torch.cuda.empty_cache()
66
+ time_metrics = []
67
+ cpu_metrics = []
68
+ memory_metrics = []
69
+ gpu_memory_metrics = []
70
+ for _ in range(times):
71
+ start = time.time()
72
+ run_model(model, size)
73
+ torch.cuda.synchronize()
74
+
75
+ # cpu_metrics.append(process.cpu_percent())
76
+ time_metrics.append((time.time() - start) * 1000)
77
+ memory_metrics.append(process.memory_info().rss / 1024 / 1024)
78
+ gpu_memory_metrics.append(
79
+ nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
80
+ )
81
+
82
+ print(f"size: {size}".center(80, "-"))
83
+ # print(f"cpu: {format(cpu_metrics)}")
84
+ print(f"latency: {format(time_metrics)}ms")
85
+ print(f"memory: {format(memory_metrics)} MB")
86
+ print(f"gpu memory: {format(gpu_memory_metrics)} MB")
87
+
88
+ nvidia_smi.nvmlShutdown()
89
+
90
+
91
+ def get_args_parser():
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument("--name")
94
+ parser.add_argument("--device", default="cuda", type=str)
95
+ parser.add_argument("--times", default=10, type=int)
96
+ parser.add_argument("--empty-cache", action="store_true")
97
+ return parser.parse_args()
98
+
99
+
100
+ if __name__ == "__main__":
101
+ args = get_args_parser()
102
+ device = torch.device(args.device)
103
+ model = ModelManager(
104
+ name=args.name,
105
+ device=device,
106
+ disable_nsfw=True,
107
+ sd_cpu_textencoder=True,
108
+ )
109
+ benchmark(model, args.times, args.empty_cache)
iopaint/file_manager/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .file_manager import FileManager
iopaint/model/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .anytext.anytext_model import AnyText
2
+ from .controlnet import ControlNet
3
+ from .fcf import FcF
4
+ from .instruct_pix2pix import InstructPix2Pix
5
+ from .kandinsky import Kandinsky22
6
+ from .lama import LaMa
7
+ from .ldm import LDM
8
+ from .manga import Manga
9
+ from .mat import MAT
10
+ from .mi_gan import MIGAN
11
+ from .opencv2 import OpenCV2
12
+ from .paint_by_example import PaintByExample
13
+ from .power_paint.power_paint import PowerPaint
14
+ from .sd import SD15, SD2, Anything4, RealisticVision14, SD
15
+ from .sdxl import SDXL
16
+ from .zits import ZITS
17
+
18
+ models = {
19
+ LaMa.name: LaMa,
20
+ LDM.name: LDM,
21
+ ZITS.name: ZITS,
22
+ MAT.name: MAT,
23
+ FcF.name: FcF,
24
+ OpenCV2.name: OpenCV2,
25
+ Manga.name: Manga,
26
+ MIGAN.name: MIGAN,
27
+ SD15.name: SD15,
28
+ Anything4.name: Anything4,
29
+ RealisticVision14.name: RealisticVision14,
30
+ SD2.name: SD2,
31
+ PaintByExample.name: PaintByExample,
32
+ InstructPix2Pix.name: InstructPix2Pix,
33
+ Kandinsky22.name: Kandinsky22,
34
+ SDXL.name: SDXL,
35
+ PowerPaint.name: PowerPaint,
36
+ AnyText.name: AnyText,
37
+ }
iopaint/model/anytext/__init__.py ADDED
File without changes
iopaint/model/anytext/anytext_model.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from huggingface_hub import hf_hub_download
3
+
4
+ from iopaint.const import ANYTEXT_NAME
5
+ from iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
6
+ from iopaint.model.base import DiffusionInpaintModel
7
+ from iopaint.model.utils import get_torch_dtype, is_local_files_only
8
+ from iopaint.schema import InpaintRequest
9
+
10
+
11
+ class AnyText(DiffusionInpaintModel):
12
+ name = ANYTEXT_NAME
13
+ pad_mod = 64
14
+ is_erase_model = False
15
+
16
+ @staticmethod
17
+ def download(local_files_only=False):
18
+ hf_hub_download(
19
+ repo_id=ANYTEXT_NAME,
20
+ filename="model_index.json",
21
+ local_files_only=local_files_only,
22
+ )
23
+ ckpt_path = hf_hub_download(
24
+ repo_id=ANYTEXT_NAME,
25
+ filename="pytorch_model.fp16.safetensors",
26
+ local_files_only=local_files_only,
27
+ )
28
+ font_path = hf_hub_download(
29
+ repo_id=ANYTEXT_NAME,
30
+ filename="SourceHanSansSC-Medium.otf",
31
+ local_files_only=local_files_only,
32
+ )
33
+ return ckpt_path, font_path
34
+
35
+ def init_model(self, device, **kwargs):
36
+ local_files_only = is_local_files_only(**kwargs)
37
+ ckpt_path, font_path = self.download(local_files_only)
38
+ use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
39
+ self.model = AnyTextPipeline(
40
+ ckpt_path=ckpt_path,
41
+ font_path=font_path,
42
+ device=device,
43
+ use_fp16=torch_dtype == torch.float16,
44
+ )
45
+ self.callback = kwargs.pop("callback", None)
46
+
47
+ def forward(self, image, mask, config: InpaintRequest):
48
+ """Input image and output image have same size
49
+ image: [H, W, C] RGB
50
+ mask: [H, W, 1] 255 means area to inpainting
51
+ return: BGR IMAGE
52
+ """
53
+ height, width = image.shape[:2]
54
+ mask = mask.astype("float32") / 255.0
55
+ masked_image = image * (1 - mask)
56
+
57
+ # list of rgb ndarray
58
+ results, rtn_code, rtn_warning = self.model(
59
+ image=image,
60
+ masked_image=masked_image,
61
+ prompt=config.prompt,
62
+ negative_prompt=config.negative_prompt,
63
+ num_inference_steps=config.sd_steps,
64
+ strength=config.sd_strength,
65
+ guidance_scale=config.sd_guidance_scale,
66
+ height=height,
67
+ width=width,
68
+ seed=config.sd_seed,
69
+ sort_priority="y",
70
+ callback=self.callback
71
+ )
72
+ inpainted_rgb_image = results[0][..., ::-1]
73
+ return inpainted_rgb_image
iopaint/model/anytext/anytext_pipeline.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AnyText: Multilingual Visual Text Generation And Editing
3
+ Paper: https://arxiv.org/abs/2311.03054
4
+ Code: https://github.com/tyxsspa/AnyText
5
+ Copyright (c) Alibaba, Inc. and its affiliates.
6
+ """
7
+ import os
8
+ from pathlib import Path
9
+
10
+ from iopaint.model.utils import set_seed
11
+ from safetensors.torch import load_file
12
+
13
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
14
+ import torch
15
+ import re
16
+ import numpy as np
17
+ import cv2
18
+ import einops
19
+ from PIL import ImageFont
20
+ from iopaint.model.anytext.cldm.model import create_model, load_state_dict
21
+ from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
22
+ from iopaint.model.anytext.utils import (
23
+ check_channels,
24
+ draw_glyph,
25
+ draw_glyph2,
26
+ )
27
+
28
+
29
+ BBOX_MAX_NUM = 8
30
+ PLACE_HOLDER = "*"
31
+ max_chars = 20
32
+
33
+ ANYTEXT_CFG = os.path.join(
34
+ os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
35
+ )
36
+
37
+
38
+ def check_limits(tensor):
39
+ float16_min = torch.finfo(torch.float16).min
40
+ float16_max = torch.finfo(torch.float16).max
41
+
42
+ # 检查张量中是否有值小于float16的最小值或大于float16的最大值
43
+ is_below_min = (tensor < float16_min).any()
44
+ is_above_max = (tensor > float16_max).any()
45
+
46
+ return is_below_min or is_above_max
47
+
48
+
49
+ class AnyTextPipeline:
50
+ def __init__(self, ckpt_path, font_path, device, use_fp16=True):
51
+ self.cfg_path = ANYTEXT_CFG
52
+ self.font_path = font_path
53
+ self.use_fp16 = use_fp16
54
+ self.device = device
55
+
56
+ self.font = ImageFont.truetype(font_path, size=60)
57
+ self.model = create_model(
58
+ self.cfg_path,
59
+ device=self.device,
60
+ use_fp16=self.use_fp16,
61
+ )
62
+ if self.use_fp16:
63
+ self.model = self.model.half()
64
+ if Path(ckpt_path).suffix == ".safetensors":
65
+ state_dict = load_file(ckpt_path, device="cpu")
66
+ else:
67
+ state_dict = load_state_dict(ckpt_path, location="cpu")
68
+ self.model.load_state_dict(state_dict, strict=False)
69
+ self.model = self.model.eval().to(self.device)
70
+ self.ddim_sampler = DDIMSampler(self.model, device=self.device)
71
+
72
+ def __call__(
73
+ self,
74
+ prompt: str,
75
+ negative_prompt: str,
76
+ image: np.ndarray,
77
+ masked_image: np.ndarray,
78
+ num_inference_steps: int,
79
+ strength: float,
80
+ guidance_scale: float,
81
+ height: int,
82
+ width: int,
83
+ seed: int,
84
+ sort_priority: str = "y",
85
+ callback=None,
86
+ ):
87
+ """
88
+
89
+ Args:
90
+ prompt:
91
+ negative_prompt:
92
+ image:
93
+ masked_image:
94
+ num_inference_steps:
95
+ strength:
96
+ guidance_scale:
97
+ height:
98
+ width:
99
+ seed:
100
+ sort_priority: x: left-right, y: top-down
101
+
102
+ Returns:
103
+ result: list of images in numpy.ndarray format
104
+ rst_code: 0: normal -1: error 1:warning
105
+ rst_info: string of error or warning
106
+
107
+ """
108
+ set_seed(seed)
109
+ str_warning = ""
110
+
111
+ mode = "text-editing"
112
+ revise_pos = False
113
+ img_count = 1
114
+ ddim_steps = num_inference_steps
115
+ w = width
116
+ h = height
117
+ strength = strength
118
+ cfg_scale = guidance_scale
119
+ eta = 0.0
120
+
121
+ prompt, texts = self.modify_prompt(prompt)
122
+ if prompt is None and texts is None:
123
+ return (
124
+ None,
125
+ -1,
126
+ "You have input Chinese prompt but the translator is not loaded!",
127
+ "",
128
+ )
129
+ n_lines = len(texts)
130
+ if mode in ["text-generation", "gen"]:
131
+ edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
132
+ elif mode in ["text-editing", "edit"]:
133
+ if masked_image is None or image is None:
134
+ return (
135
+ None,
136
+ -1,
137
+ "Reference image and position image are needed for text editing!",
138
+ "",
139
+ )
140
+ if isinstance(image, str):
141
+ image = cv2.imread(image)[..., ::-1]
142
+ assert image is not None, f"Can't read ori_image image from{image}!"
143
+ elif isinstance(image, torch.Tensor):
144
+ image = image.cpu().numpy()
145
+ else:
146
+ assert isinstance(
147
+ image, np.ndarray
148
+ ), f"Unknown format of ori_image: {type(image)}"
149
+ edit_image = image.clip(1, 255) # for mask reason
150
+ edit_image = check_channels(edit_image)
151
+ # edit_image = resize_image(
152
+ # edit_image, max_length=768
153
+ # ) # make w h multiple of 64, resize if w or h > max_length
154
+ h, w = edit_image.shape[:2] # change h, w by input ref_img
155
+ # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
156
+ if masked_image is None:
157
+ pos_imgs = np.zeros((w, h, 1))
158
+ if isinstance(masked_image, str):
159
+ masked_image = cv2.imread(masked_image)[..., ::-1]
160
+ assert (
161
+ masked_image is not None
162
+ ), f"Can't read draw_pos image from{masked_image}!"
163
+ pos_imgs = 255 - masked_image
164
+ elif isinstance(masked_image, torch.Tensor):
165
+ pos_imgs = masked_image.cpu().numpy()
166
+ else:
167
+ assert isinstance(
168
+ masked_image, np.ndarray
169
+ ), f"Unknown format of draw_pos: {type(masked_image)}"
170
+ pos_imgs = 255 - masked_image
171
+ pos_imgs = pos_imgs[..., 0:1]
172
+ pos_imgs = cv2.convertScaleAbs(pos_imgs)
173
+ _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
174
+ # seprate pos_imgs
175
+ pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
176
+ if len(pos_imgs) == 0:
177
+ pos_imgs = [np.zeros((h, w, 1))]
178
+ if len(pos_imgs) < n_lines:
179
+ if n_lines == 1 and texts[0] == " ":
180
+ pass # text-to-image without text
181
+ else:
182
+ raise RuntimeError(
183
+ f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
184
+ )
185
+ elif len(pos_imgs) > n_lines:
186
+ str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
187
+ # get pre_pos, poly_list, hint that needed for anytext
188
+ pre_pos = []
189
+ poly_list = []
190
+ for input_pos in pos_imgs:
191
+ if input_pos.mean() != 0:
192
+ input_pos = (
193
+ input_pos[..., np.newaxis]
194
+ if len(input_pos.shape) == 2
195
+ else input_pos
196
+ )
197
+ poly, pos_img = self.find_polygon(input_pos)
198
+ pre_pos += [pos_img / 255.0]
199
+ poly_list += [poly]
200
+ else:
201
+ pre_pos += [np.zeros((h, w, 1))]
202
+ poly_list += [None]
203
+ np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
204
+ # prepare info dict
205
+ info = {}
206
+ info["glyphs"] = []
207
+ info["gly_line"] = []
208
+ info["positions"] = []
209
+ info["n_lines"] = [len(texts)] * img_count
210
+ gly_pos_imgs = []
211
+ for i in range(len(texts)):
212
+ text = texts[i]
213
+ if len(text) > max_chars:
214
+ str_warning = (
215
+ f'"{text}" length > max_chars: {max_chars}, will be cut off...'
216
+ )
217
+ text = text[:max_chars]
218
+ gly_scale = 2
219
+ if pre_pos[i].mean() != 0:
220
+ gly_line = draw_glyph(self.font, text)
221
+ glyphs = draw_glyph2(
222
+ self.font,
223
+ text,
224
+ poly_list[i],
225
+ scale=gly_scale,
226
+ width=w,
227
+ height=h,
228
+ add_space=False,
229
+ )
230
+ gly_pos_img = cv2.drawContours(
231
+ glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
232
+ )
233
+ if revise_pos:
234
+ resize_gly = cv2.resize(
235
+ glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
236
+ )
237
+ new_pos = cv2.morphologyEx(
238
+ (resize_gly * 255).astype(np.uint8),
239
+ cv2.MORPH_CLOSE,
240
+ kernel=np.ones(
241
+ (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
242
+ dtype=np.uint8,
243
+ ),
244
+ iterations=1,
245
+ )
246
+ new_pos = (
247
+ new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
248
+ )
249
+ contours, _ = cv2.findContours(
250
+ new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
251
+ )
252
+ if len(contours) != 1:
253
+ str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
254
+ else:
255
+ rect = cv2.minAreaRect(contours[0])
256
+ poly = np.int0(cv2.boxPoints(rect))
257
+ pre_pos[i] = (
258
+ cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
259
+ )
260
+ gly_pos_img = cv2.drawContours(
261
+ glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
262
+ )
263
+ gly_pos_imgs += [gly_pos_img] # for show
264
+ else:
265
+ glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
266
+ gly_line = np.zeros((80, 512, 1))
267
+ gly_pos_imgs += [
268
+ np.zeros((h * gly_scale, w * gly_scale, 1))
269
+ ] # for show
270
+ pos = pre_pos[i]
271
+ info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
272
+ info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
273
+ info["positions"] += [self.arr2tensor(pos, img_count)]
274
+ # get masked_x
275
+ masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
276
+ masked_img = np.transpose(masked_img, (2, 0, 1))
277
+ masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
278
+ if self.use_fp16:
279
+ masked_img = masked_img.half()
280
+ encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
281
+ masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
282
+ if self.use_fp16:
283
+ masked_x = masked_x.half()
284
+ info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)
285
+
286
+ hint = self.arr2tensor(np_hint, img_count)
287
+ cond = self.model.get_learned_conditioning(
288
+ dict(
289
+ c_concat=[hint],
290
+ c_crossattn=[[prompt] * img_count],
291
+ text_info=info,
292
+ )
293
+ )
294
+ un_cond = self.model.get_learned_conditioning(
295
+ dict(
296
+ c_concat=[hint],
297
+ c_crossattn=[[negative_prompt] * img_count],
298
+ text_info=info,
299
+ )
300
+ )
301
+ shape = (4, h // 8, w // 8)
302
+ self.model.control_scales = [strength] * 13
303
+ samples, intermediates = self.ddim_sampler.sample(
304
+ ddim_steps,
305
+ img_count,
306
+ shape,
307
+ cond,
308
+ verbose=False,
309
+ eta=eta,
310
+ unconditional_guidance_scale=cfg_scale,
311
+ unconditional_conditioning=un_cond,
312
+ callback=callback
313
+ )
314
+ if self.use_fp16:
315
+ samples = samples.half()
316
+ x_samples = self.model.decode_first_stage(samples)
317
+ x_samples = (
318
+ (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
319
+ .cpu()
320
+ .numpy()
321
+ .clip(0, 255)
322
+ .astype(np.uint8)
323
+ )
324
+ results = [x_samples[i] for i in range(img_count)]
325
+ # if (
326
+ # mode == "edit" and False
327
+ # ): # replace backgound in text editing but not ideal yet
328
+ # results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
329
+ # results = [r.clip(0, 255).astype(np.uint8) for r in results]
330
+ # if len(gly_pos_imgs) > 0 and show_debug:
331
+ # glyph_bs = np.stack(gly_pos_imgs, axis=2)
332
+ # glyph_img = np.sum(glyph_bs, axis=2) * 255
333
+ # glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
334
+ # results += [np.repeat(glyph_img, 3, axis=2)]
335
+ rst_code = 1 if str_warning else 0
336
+ return results, rst_code, str_warning
337
+
338
+ def modify_prompt(self, prompt):
339
+ prompt = prompt.replace("“", '"')
340
+ prompt = prompt.replace("”", '"')
341
+ p = '"(.*?)"'
342
+ strs = re.findall(p, prompt)
343
+ if len(strs) == 0:
344
+ strs = [" "]
345
+ else:
346
+ for s in strs:
347
+ prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
348
+ # if self.is_chinese(prompt):
349
+ # if self.trans_pipe is None:
350
+ # return None, None
351
+ # old_prompt = prompt
352
+ # prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
353
+ # print(f"Translate: {old_prompt} --> {prompt}")
354
+ return prompt, strs
355
+
356
+ # def is_chinese(self, text):
357
+ # text = checker._clean_text(text)
358
+ # for char in text:
359
+ # cp = ord(char)
360
+ # if checker._is_chinese_char(cp):
361
+ # return True
362
+ # return False
363
+
364
+ def separate_pos_imgs(self, img, sort_priority, gap=102):
365
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
366
+ components = []
367
+ for label in range(1, num_labels):
368
+ component = np.zeros_like(img)
369
+ component[labels == label] = 255
370
+ components.append((component, centroids[label]))
371
+ if sort_priority == "y":
372
+ fir, sec = 1, 0 # top-down first
373
+ elif sort_priority == "x":
374
+ fir, sec = 0, 1 # left-right first
375
+ components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
376
+ sorted_components = [c[0] for c in components]
377
+ return sorted_components
378
+
379
+ def find_polygon(self, image, min_rect=False):
380
+ contours, hierarchy = cv2.findContours(
381
+ image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
382
+ )
383
+ max_contour = max(contours, key=cv2.contourArea) # get contour with max area
384
+ if min_rect:
385
+ # get minimum enclosing rectangle
386
+ rect = cv2.minAreaRect(max_contour)
387
+ poly = np.int0(cv2.boxPoints(rect))
388
+ else:
389
+ # get approximate polygon
390
+ epsilon = 0.01 * cv2.arcLength(max_contour, True)
391
+ poly = cv2.approxPolyDP(max_contour, epsilon, True)
392
+ n, _, xy = poly.shape
393
+ poly = poly.reshape(n, xy)
394
+ cv2.drawContours(image, [poly], -1, 255, -1)
395
+ return poly, image
396
+
397
+ def arr2tensor(self, arr, bs):
398
+ arr = np.transpose(arr, (2, 0, 1))
399
+ _arr = torch.from_numpy(arr.copy()).float().to(self.device)
400
+ if self.use_fp16:
401
+ _arr = _arr.half()
402
+ _arr = torch.stack([_arr for _ in range(bs)], dim=0)
403
+ return _arr
iopaint/model/anytext/anytext_sd15.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: iopaint.model.anytext.cldm.cldm.ControlLDM
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ num_timesteps_cond: 1
7
+ log_every_t: 200
8
+ timesteps: 1000
9
+ first_stage_key: "img"
10
+ cond_stage_key: "caption"
11
+ control_key: "hint"
12
+ glyph_key: "glyphs"
13
+ position_key: "positions"
14
+ image_size: 64
15
+ channels: 4
16
+ cond_stage_trainable: true # need be true when embedding_manager is valid
17
+ conditioning_key: crossattn
18
+ monitor: val/loss_simple_ema
19
+ scale_factor: 0.18215
20
+ use_ema: False
21
+ only_mid_control: False
22
+ loss_alpha: 0 # perceptual loss, 0.003
23
+ loss_beta: 0 # ctc loss
24
+ latin_weight: 1.0 # latin text line may need smaller weigth
25
+ with_step_weight: true
26
+ use_vae_upsample: true
27
+ embedding_manager_config:
28
+ target: iopaint.model.anytext.cldm.embedding_manager.EmbeddingManager
29
+ params:
30
+ valid: true # v6
31
+ emb_type: ocr # ocr, vit, conv
32
+ glyph_channels: 1
33
+ position_channels: 1
34
+ add_pos: false
35
+ placeholder_string: '*'
36
+
37
+ control_stage_config:
38
+ target: iopaint.model.anytext.cldm.cldm.ControlNet
39
+ params:
40
+ image_size: 32 # unused
41
+ in_channels: 4
42
+ model_channels: 320
43
+ glyph_channels: 1
44
+ position_channels: 1
45
+ attention_resolutions: [ 4, 2, 1 ]
46
+ num_res_blocks: 2
47
+ channel_mult: [ 1, 2, 4, 4 ]
48
+ num_heads: 8
49
+ use_spatial_transformer: True
50
+ transformer_depth: 1
51
+ context_dim: 768
52
+ use_checkpoint: True
53
+ legacy: False
54
+
55
+ unet_config:
56
+ target: iopaint.model.anytext.cldm.cldm.ControlledUnetModel
57
+ params:
58
+ image_size: 32 # unused
59
+ in_channels: 4
60
+ out_channels: 4
61
+ model_channels: 320
62
+ attention_resolutions: [ 4, 2, 1 ]
63
+ num_res_blocks: 2
64
+ channel_mult: [ 1, 2, 4, 4 ]
65
+ num_heads: 8
66
+ use_spatial_transformer: True
67
+ transformer_depth: 1
68
+ context_dim: 768
69
+ use_checkpoint: True
70
+ legacy: False
71
+
72
+ first_stage_config:
73
+ target: iopaint.model.anytext.ldm.models.autoencoder.AutoencoderKL
74
+ params:
75
+ embed_dim: 4
76
+ monitor: val/rec_loss
77
+ ddconfig:
78
+ double_z: true
79
+ z_channels: 4
80
+ resolution: 256
81
+ in_channels: 3
82
+ out_ch: 3
83
+ ch: 128
84
+ ch_mult:
85
+ - 1
86
+ - 2
87
+ - 4
88
+ - 4
89
+ num_res_blocks: 2
90
+ attn_resolutions: []
91
+ dropout: 0.0
92
+ lossconfig:
93
+ target: torch.nn.Identity
94
+
95
+ cond_stage_config:
96
+ target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
97
+ params:
98
+ version: openai/clip-vit-large-patch14
99
+ use_vision: false # v6
iopaint/model/anytext/cldm/__init__.py ADDED
File without changes
iopaint/model/anytext/ldm/__init__.py ADDED
File without changes
iopaint/model/anytext/ldm/models/__init__.py ADDED
File without changes
iopaint/model/anytext/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from contextlib import contextmanager
4
+
5
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.model import Encoder, Decoder
6
+ from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
7
+
8
+ from iopaint.model.anytext.ldm.util import instantiate_from_config
9
+ from iopaint.model.anytext.ldm.modules.ema import LitEma
10
+
11
+
12
+ class AutoencoderKL(torch.nn.Module):
13
+ def __init__(self,
14
+ ddconfig,
15
+ lossconfig,
16
+ embed_dim,
17
+ ckpt_path=None,
18
+ ignore_keys=[],
19
+ image_key="image",
20
+ colorize_nlabels=None,
21
+ monitor=None,
22
+ ema_decay=None,
23
+ learn_logvar=False
24
+ ):
25
+ super().__init__()
26
+ self.learn_logvar = learn_logvar
27
+ self.image_key = image_key
28
+ self.encoder = Encoder(**ddconfig)
29
+ self.decoder = Decoder(**ddconfig)
30
+ self.loss = instantiate_from_config(lossconfig)
31
+ assert ddconfig["double_z"]
32
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
33
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
34
+ self.embed_dim = embed_dim
35
+ if colorize_nlabels is not None:
36
+ assert type(colorize_nlabels)==int
37
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
38
+ if monitor is not None:
39
+ self.monitor = monitor
40
+
41
+ self.use_ema = ema_decay is not None
42
+ if self.use_ema:
43
+ self.ema_decay = ema_decay
44
+ assert 0. < ema_decay < 1.
45
+ self.model_ema = LitEma(self, decay=ema_decay)
46
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
47
+
48
+ if ckpt_path is not None:
49
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
50
+
51
+ def init_from_ckpt(self, path, ignore_keys=list()):
52
+ sd = torch.load(path, map_location="cpu")["state_dict"]
53
+ keys = list(sd.keys())
54
+ for k in keys:
55
+ for ik in ignore_keys:
56
+ if k.startswith(ik):
57
+ print("Deleting key {} from state_dict.".format(k))
58
+ del sd[k]
59
+ self.load_state_dict(sd, strict=False)
60
+ print(f"Restored from {path}")
61
+
62
+ @contextmanager
63
+ def ema_scope(self, context=None):
64
+ if self.use_ema:
65
+ self.model_ema.store(self.parameters())
66
+ self.model_ema.copy_to(self)
67
+ if context is not None:
68
+ print(f"{context}: Switched to EMA weights")
69
+ try:
70
+ yield None
71
+ finally:
72
+ if self.use_ema:
73
+ self.model_ema.restore(self.parameters())
74
+ if context is not None:
75
+ print(f"{context}: Restored training weights")
76
+
77
+ def on_train_batch_end(self, *args, **kwargs):
78
+ if self.use_ema:
79
+ self.model_ema(self)
80
+
81
+ def encode(self, x):
82
+ h = self.encoder(x)
83
+ moments = self.quant_conv(h)
84
+ posterior = DiagonalGaussianDistribution(moments)
85
+ return posterior
86
+
87
+ def decode(self, z):
88
+ z = self.post_quant_conv(z)
89
+ dec = self.decoder(z)
90
+ return dec
91
+
92
+ def forward(self, input, sample_posterior=True):
93
+ posterior = self.encode(input)
94
+ if sample_posterior:
95
+ z = posterior.sample()
96
+ else:
97
+ z = posterior.mode()
98
+ dec = self.decode(z)
99
+ return dec, posterior
100
+
101
+ def get_input(self, batch, k):
102
+ x = batch[k]
103
+ if len(x.shape) == 3:
104
+ x = x[..., None]
105
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
106
+ return x
107
+
108
+ def training_step(self, batch, batch_idx, optimizer_idx):
109
+ inputs = self.get_input(batch, self.image_key)
110
+ reconstructions, posterior = self(inputs)
111
+
112
+ if optimizer_idx == 0:
113
+ # train encoder+decoder+logvar
114
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
115
+ last_layer=self.get_last_layer(), split="train")
116
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
117
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
118
+ return aeloss
119
+
120
+ if optimizer_idx == 1:
121
+ # train the discriminator
122
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
123
+ last_layer=self.get_last_layer(), split="train")
124
+
125
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
126
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
127
+ return discloss
128
+
129
+ def validation_step(self, batch, batch_idx):
130
+ log_dict = self._validation_step(batch, batch_idx)
131
+ with self.ema_scope():
132
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
133
+ return log_dict
134
+
135
+ def _validation_step(self, batch, batch_idx, postfix=""):
136
+ inputs = self.get_input(batch, self.image_key)
137
+ reconstructions, posterior = self(inputs)
138
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
139
+ last_layer=self.get_last_layer(), split="val"+postfix)
140
+
141
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
142
+ last_layer=self.get_last_layer(), split="val"+postfix)
143
+
144
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
145
+ self.log_dict(log_dict_ae)
146
+ self.log_dict(log_dict_disc)
147
+ return self.log_dict
148
+
149
+ def configure_optimizers(self):
150
+ lr = self.learning_rate
151
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
152
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
153
+ if self.learn_logvar:
154
+ print(f"{self.__class__.__name__}: Learning logvar")
155
+ ae_params_list.append(self.loss.logvar)
156
+ opt_ae = torch.optim.Adam(ae_params_list,
157
+ lr=lr, betas=(0.5, 0.9))
158
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
159
+ lr=lr, betas=(0.5, 0.9))
160
+ return [opt_ae, opt_disc], []
161
+
162
+ def get_last_layer(self):
163
+ return self.decoder.conv_out.weight
164
+
165
+ @torch.no_grad()
166
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
167
+ log = dict()
168
+ x = self.get_input(batch, self.image_key)
169
+ x = x.to(self.device)
170
+ if not only_inputs:
171
+ xrec, posterior = self(x)
172
+ if x.shape[1] > 3:
173
+ # colorize with random projection
174
+ assert xrec.shape[1] > 3
175
+ x = self.to_rgb(x)
176
+ xrec = self.to_rgb(xrec)
177
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
178
+ log["reconstructions"] = xrec
179
+ if log_ema or self.use_ema:
180
+ with self.ema_scope():
181
+ xrec_ema, posterior_ema = self(x)
182
+ if x.shape[1] > 3:
183
+ # colorize with random projection
184
+ assert xrec_ema.shape[1] > 3
185
+ xrec_ema = self.to_rgb(xrec_ema)
186
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
187
+ log["reconstructions_ema"] = xrec_ema
188
+ log["inputs"] = x
189
+ return log
190
+
191
+ def to_rgb(self, x):
192
+ assert self.image_key == "segmentation"
193
+ if not hasattr(self, "colorize"):
194
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
195
+ x = F.conv2d(x, weight=self.colorize)
196
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
197
+ return x
198
+
199
+
200
+ class IdentityFirstStage(torch.nn.Module):
201
+ def __init__(self, *args, vq_interface=False, **kwargs):
202
+ self.vq_interface = vq_interface
203
+ super().__init__()
204
+
205
+ def encode(self, x, *args, **kwargs):
206
+ return x
207
+
208
+ def decode(self, x, *args, **kwargs):
209
+ return x
210
+
211
+ def quantize(self, x, *args, **kwargs):
212
+ if self.vq_interface:
213
+ return x, None, [None, None, None]
214
+ return x
215
+
216
+ def forward(self, x, *args, **kwargs):
217
+ return x
218
+
iopaint/model/anytext/ldm/models/diffusion/__init__.py ADDED
File without changes
iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampler import DPMSolverSampler
iopaint/model/anytext/ldm/modules/__init__.py ADDED
File without changes
iopaint/model/anytext/ldm/modules/attention.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
+
9
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.util import checkpoint
10
+
11
+
12
+ # CrossAttn precision handling
13
+ import os
14
+
15
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
16
+
17
+
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ def uniq(arr):
23
+ return {el: True for el in arr}.keys()
24
+
25
+
26
+ def default(val, d):
27
+ if exists(val):
28
+ return val
29
+ return d() if isfunction(d) else d
30
+
31
+
32
+ def max_neg_value(t):
33
+ return -torch.finfo(t.dtype).max
34
+
35
+
36
+ def init_(tensor):
37
+ dim = tensor.shape[-1]
38
+ std = 1 / math.sqrt(dim)
39
+ tensor.uniform_(-std, std)
40
+ return tensor
41
+
42
+
43
+ # feedforward
44
+ class GEGLU(nn.Module):
45
+ def __init__(self, dim_in, dim_out):
46
+ super().__init__()
47
+ self.proj = nn.Linear(dim_in, dim_out * 2)
48
+
49
+ def forward(self, x):
50
+ x, gate = self.proj(x).chunk(2, dim=-1)
51
+ return x * F.gelu(gate)
52
+
53
+
54
+ class FeedForward(nn.Module):
55
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
56
+ super().__init__()
57
+ inner_dim = int(dim * mult)
58
+ dim_out = default(dim_out, dim)
59
+ project_in = (
60
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
61
+ if not glu
62
+ else GEGLU(dim, inner_dim)
63
+ )
64
+
65
+ self.net = nn.Sequential(
66
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
67
+ )
68
+
69
+ def forward(self, x):
70
+ return self.net(x)
71
+
72
+
73
+ def zero_module(module):
74
+ """
75
+ Zero out the parameters of a module and return it.
76
+ """
77
+ for p in module.parameters():
78
+ p.detach().zero_()
79
+ return module
80
+
81
+
82
+ def Normalize(in_channels):
83
+ return torch.nn.GroupNorm(
84
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
85
+ )
86
+
87
+
88
+ class SpatialSelfAttention(nn.Module):
89
+ def __init__(self, in_channels):
90
+ super().__init__()
91
+ self.in_channels = in_channels
92
+
93
+ self.norm = Normalize(in_channels)
94
+ self.q = torch.nn.Conv2d(
95
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
96
+ )
97
+ self.k = torch.nn.Conv2d(
98
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
99
+ )
100
+ self.v = torch.nn.Conv2d(
101
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
102
+ )
103
+ self.proj_out = torch.nn.Conv2d(
104
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
105
+ )
106
+
107
+ def forward(self, x):
108
+ h_ = x
109
+ h_ = self.norm(h_)
110
+ q = self.q(h_)
111
+ k = self.k(h_)
112
+ v = self.v(h_)
113
+
114
+ # compute attention
115
+ b, c, h, w = q.shape
116
+ q = rearrange(q, "b c h w -> b (h w) c")
117
+ k = rearrange(k, "b c h w -> b c (h w)")
118
+ w_ = torch.einsum("bij,bjk->bik", q, k)
119
+
120
+ w_ = w_ * (int(c) ** (-0.5))
121
+ w_ = torch.nn.functional.softmax(w_, dim=2)
122
+
123
+ # attend to values
124
+ v = rearrange(v, "b c h w -> b c (h w)")
125
+ w_ = rearrange(w_, "b i j -> b j i")
126
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
127
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
128
+ h_ = self.proj_out(h_)
129
+
130
+ return x + h_
131
+
132
+
133
+ class CrossAttention(nn.Module):
134
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
135
+ super().__init__()
136
+ inner_dim = dim_head * heads
137
+ context_dim = default(context_dim, query_dim)
138
+
139
+ self.scale = dim_head**-0.5
140
+ self.heads = heads
141
+
142
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
143
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
144
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
145
+
146
+ self.to_out = nn.Sequential(
147
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
148
+ )
149
+
150
+ def forward(self, x, context=None, mask=None):
151
+ h = self.heads
152
+
153
+ q = self.to_q(x)
154
+ context = default(context, x)
155
+ k = self.to_k(context)
156
+ v = self.to_v(context)
157
+
158
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
159
+
160
+ # force cast to fp32 to avoid overflowing
161
+ if _ATTN_PRECISION == "fp32":
162
+ with torch.autocast(enabled=False, device_type="cuda"):
163
+ q, k = q.float(), k.float()
164
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
165
+ else:
166
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
167
+
168
+ del q, k
169
+
170
+ if exists(mask):
171
+ mask = rearrange(mask, "b ... -> b (...)")
172
+ max_neg_value = -torch.finfo(sim.dtype).max
173
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
174
+ sim.masked_fill_(~mask, max_neg_value)
175
+
176
+ # attention, what we cannot get enough of
177
+ sim = sim.softmax(dim=-1)
178
+
179
+ out = einsum("b i j, b j d -> b i d", sim, v)
180
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
181
+ return self.to_out(out)
182
+
183
+
184
+ class SDPACrossAttention(CrossAttention):
185
+ def forward(self, x, context=None, mask=None):
186
+ batch_size, sequence_length, inner_dim = x.shape
187
+
188
+ if mask is not None:
189
+ mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
190
+ mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
191
+
192
+ h = self.heads
193
+ q_in = self.to_q(x)
194
+ context = default(context, x)
195
+
196
+ k_in = self.to_k(context)
197
+ v_in = self.to_v(context)
198
+
199
+ head_dim = inner_dim // h
200
+ q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
201
+ k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
202
+ v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
203
+
204
+ del q_in, k_in, v_in
205
+
206
+ dtype = q.dtype
207
+ if _ATTN_PRECISION == "fp32":
208
+ q, k, v = q.float(), k.float(), v.float()
209
+
210
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
211
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
212
+ q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
213
+ )
214
+
215
+ hidden_states = hidden_states.transpose(1, 2).reshape(
216
+ batch_size, -1, h * head_dim
217
+ )
218
+ hidden_states = hidden_states.to(dtype)
219
+
220
+ # linear proj
221
+ hidden_states = self.to_out[0](hidden_states)
222
+ # dropout
223
+ hidden_states = self.to_out[1](hidden_states)
224
+ return hidden_states
225
+
226
+
227
+ class BasicTransformerBlock(nn.Module):
228
+ def __init__(
229
+ self,
230
+ dim,
231
+ n_heads,
232
+ d_head,
233
+ dropout=0.0,
234
+ context_dim=None,
235
+ gated_ff=True,
236
+ checkpoint=True,
237
+ disable_self_attn=False,
238
+ ):
239
+ super().__init__()
240
+
241
+ if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
242
+ attn_cls = SDPACrossAttention
243
+ else:
244
+ attn_cls = CrossAttention
245
+
246
+ self.disable_self_attn = disable_self_attn
247
+ self.attn1 = attn_cls(
248
+ query_dim=dim,
249
+ heads=n_heads,
250
+ dim_head=d_head,
251
+ dropout=dropout,
252
+ context_dim=context_dim if self.disable_self_attn else None,
253
+ ) # is a self-attention if not self.disable_self_attn
254
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
255
+ self.attn2 = attn_cls(
256
+ query_dim=dim,
257
+ context_dim=context_dim,
258
+ heads=n_heads,
259
+ dim_head=d_head,
260
+ dropout=dropout,
261
+ ) # is self-attn if context is none
262
+ self.norm1 = nn.LayerNorm(dim)
263
+ self.norm2 = nn.LayerNorm(dim)
264
+ self.norm3 = nn.LayerNorm(dim)
265
+ self.checkpoint = checkpoint
266
+
267
+ def forward(self, x, context=None):
268
+ return checkpoint(
269
+ self._forward, (x, context), self.parameters(), self.checkpoint
270
+ )
271
+
272
+ def _forward(self, x, context=None):
273
+ x = (
274
+ self.attn1(
275
+ self.norm1(x), context=context if self.disable_self_attn else None
276
+ )
277
+ + x
278
+ )
279
+ x = self.attn2(self.norm2(x), context=context) + x
280
+ x = self.ff(self.norm3(x)) + x
281
+ return x
282
+
283
+
284
+ class SpatialTransformer(nn.Module):
285
+ """
286
+ Transformer block for image-like data.
287
+ First, project the input (aka embedding)
288
+ and reshape to b, t, d.
289
+ Then apply standard transformer action.
290
+ Finally, reshape to image
291
+ NEW: use_linear for more efficiency instead of the 1x1 convs
292
+ """
293
+
294
+ def __init__(
295
+ self,
296
+ in_channels,
297
+ n_heads,
298
+ d_head,
299
+ depth=1,
300
+ dropout=0.0,
301
+ context_dim=None,
302
+ disable_self_attn=False,
303
+ use_linear=False,
304
+ use_checkpoint=True,
305
+ ):
306
+ super().__init__()
307
+ if exists(context_dim) and not isinstance(context_dim, list):
308
+ context_dim = [context_dim]
309
+ self.in_channels = in_channels
310
+ inner_dim = n_heads * d_head
311
+ self.norm = Normalize(in_channels)
312
+ if not use_linear:
313
+ self.proj_in = nn.Conv2d(
314
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
315
+ )
316
+ else:
317
+ self.proj_in = nn.Linear(in_channels, inner_dim)
318
+
319
+ self.transformer_blocks = nn.ModuleList(
320
+ [
321
+ BasicTransformerBlock(
322
+ inner_dim,
323
+ n_heads,
324
+ d_head,
325
+ dropout=dropout,
326
+ context_dim=context_dim[d],
327
+ disable_self_attn=disable_self_attn,
328
+ checkpoint=use_checkpoint,
329
+ )
330
+ for d in range(depth)
331
+ ]
332
+ )
333
+ if not use_linear:
334
+ self.proj_out = zero_module(
335
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
336
+ )
337
+ else:
338
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
339
+ self.use_linear = use_linear
340
+
341
+ def forward(self, x, context=None):
342
+ # note: if no context is given, cross-attention defaults to self-attention
343
+ if not isinstance(context, list):
344
+ context = [context]
345
+ b, c, h, w = x.shape
346
+ x_in = x
347
+ x = self.norm(x)
348
+ if not self.use_linear:
349
+ x = self.proj_in(x)
350
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
351
+ if self.use_linear:
352
+ x = self.proj_in(x)
353
+ for i, block in enumerate(self.transformer_blocks):
354
+ x = block(x, context=context[i])
355
+ if self.use_linear:
356
+ x = self.proj_out(x)
357
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
358
+ if not self.use_linear:
359
+ x = self.proj_out(x)
360
+ return x + x_in
iopaint/model/anytext/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
iopaint/model/anytext/ldm/modules/distributions/__init__.py ADDED
File without changes
iopaint/model/anytext/ldm/modules/encoders/__init__.py ADDED
File without changes
iopaint/model/anytext/ocr_recog/__init__.py ADDED
File without changes
iopaint/model/base.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Optional
3
+
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ from loguru import logger
8
+
9
+ from iopaint.helper import (
10
+ boxes_from_mask,
11
+ resize_max_size,
12
+ pad_img_to_modulo,
13
+ switch_mps_device,
14
+ )
15
+ from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
16
+ from .helper.g_diffuser_bot import expand_image
17
+ from .utils import get_scheduler
18
+
19
+
20
+ class InpaintModel:
21
+ name = "base"
22
+ min_size: Optional[int] = None
23
+ pad_mod = 8
24
+ pad_to_square = False
25
+ is_erase_model = False
26
+
27
+ def __init__(self, device, **kwargs):
28
+ """
29
+
30
+ Args:
31
+ device:
32
+ """
33
+ device = switch_mps_device(self.name, device)
34
+ self.device = device
35
+ self.init_model(device, **kwargs)
36
+
37
+ @abc.abstractmethod
38
+ def init_model(self, device, **kwargs):
39
+ ...
40
+
41
+ @staticmethod
42
+ @abc.abstractmethod
43
+ def is_downloaded() -> bool:
44
+ return False
45
+
46
+ @abc.abstractmethod
47
+ def forward(self, image, mask, config: InpaintRequest):
48
+ """Input images and output images have same size
49
+ images: [H, W, C] RGB
50
+ masks: [H, W, 1] 255 为 masks 区域
51
+ return: BGR IMAGE
52
+ """
53
+ ...
54
+
55
+ @staticmethod
56
+ def download():
57
+ ...
58
+
59
+ def _pad_forward(self, image, mask, config: InpaintRequest):
60
+ origin_height, origin_width = image.shape[:2]
61
+ pad_image = pad_img_to_modulo(
62
+ image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
63
+ )
64
+ pad_mask = pad_img_to_modulo(
65
+ mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
66
+ )
67
+
68
+ # logger.info(f"final forward pad size: {pad_image.shape}")
69
+
70
+ image, mask = self.forward_pre_process(image, mask, config)
71
+
72
+ result = self.forward(pad_image, pad_mask, config)
73
+ result = result[0:origin_height, 0:origin_width, :]
74
+
75
+ result, image, mask = self.forward_post_process(result, image, mask, config)
76
+
77
+ if config.sd_keep_unmasked_area:
78
+ mask = mask[:, :, np.newaxis]
79
+ result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
80
+ return result
81
+
82
+ def forward_pre_process(self, image, mask, config):
83
+ return image, mask
84
+
85
+ def forward_post_process(self, result, image, mask, config):
86
+ return result, image, mask
87
+
88
+ @torch.no_grad()
89
+ def __call__(self, image, mask, config: InpaintRequest):
90
+ """
91
+ images: [H, W, C] RGB, not normalized
92
+ masks: [H, W]
93
+ return: BGR IMAGE
94
+ """
95
+ inpaint_result = None
96
+ # logger.info(f"hd_strategy: {config.hd_strategy}")
97
+ if config.hd_strategy == HDStrategy.CROP:
98
+ if max(image.shape) > config.hd_strategy_crop_trigger_size:
99
+ logger.info(f"Run crop strategy")
100
+ boxes = boxes_from_mask(mask)
101
+ crop_result = []
102
+ for box in boxes:
103
+ crop_image, crop_box = self._run_box(image, mask, box, config)
104
+ crop_result.append((crop_image, crop_box))
105
+
106
+ inpaint_result = image[:, :, ::-1]
107
+ for crop_image, crop_box in crop_result:
108
+ x1, y1, x2, y2 = crop_box
109
+ inpaint_result[y1:y2, x1:x2, :] = crop_image
110
+
111
+ elif config.hd_strategy == HDStrategy.RESIZE:
112
+ if max(image.shape) > config.hd_strategy_resize_limit:
113
+ origin_size = image.shape[:2]
114
+ downsize_image = resize_max_size(
115
+ image, size_limit=config.hd_strategy_resize_limit
116
+ )
117
+ downsize_mask = resize_max_size(
118
+ mask, size_limit=config.hd_strategy_resize_limit
119
+ )
120
+
121
+ logger.info(
122
+ f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
123
+ )
124
+ inpaint_result = self._pad_forward(
125
+ downsize_image, downsize_mask, config
126
+ )
127
+
128
+ # only paste masked area result
129
+ inpaint_result = cv2.resize(
130
+ inpaint_result,
131
+ (origin_size[1], origin_size[0]),
132
+ interpolation=cv2.INTER_CUBIC,
133
+ )
134
+ original_pixel_indices = mask < 127
135
+ inpaint_result[original_pixel_indices] = image[:, :, ::-1][
136
+ original_pixel_indices
137
+ ]
138
+
139
+ if inpaint_result is None:
140
+ inpaint_result = self._pad_forward(image, mask, config)
141
+
142
+ return inpaint_result
143
+
144
+ def _crop_box(self, image, mask, box, config: InpaintRequest):
145
+ """
146
+
147
+ Args:
148
+ image: [H, W, C] RGB
149
+ mask: [H, W, 1]
150
+ box: [left,top,right,bottom]
151
+
152
+ Returns:
153
+ BGR IMAGE, (l, r, r, b)
154
+ """
155
+ box_h = box[3] - box[1]
156
+ box_w = box[2] - box[0]
157
+ cx = (box[0] + box[2]) // 2
158
+ cy = (box[1] + box[3]) // 2
159
+ img_h, img_w = image.shape[:2]
160
+
161
+ w = box_w + config.hd_strategy_crop_margin * 2
162
+ h = box_h + config.hd_strategy_crop_margin * 2
163
+
164
+ _l = cx - w // 2
165
+ _r = cx + w // 2
166
+ _t = cy - h // 2
167
+ _b = cy + h // 2
168
+
169
+ l = max(_l, 0)
170
+ r = min(_r, img_w)
171
+ t = max(_t, 0)
172
+ b = min(_b, img_h)
173
+
174
+ # try to get more context when crop around image edge
175
+ if _l < 0:
176
+ r += abs(_l)
177
+ if _r > img_w:
178
+ l -= _r - img_w
179
+ if _t < 0:
180
+ b += abs(_t)
181
+ if _b > img_h:
182
+ t -= _b - img_h
183
+
184
+ l = max(l, 0)
185
+ r = min(r, img_w)
186
+ t = max(t, 0)
187
+ b = min(b, img_h)
188
+
189
+ crop_img = image[t:b, l:r, :]
190
+ crop_mask = mask[t:b, l:r]
191
+
192
+ # logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
193
+
194
+ return crop_img, crop_mask, [l, t, r, b]
195
+
196
+ def _calculate_cdf(self, histogram):
197
+ cdf = histogram.cumsum()
198
+ normalized_cdf = cdf / float(cdf.max())
199
+ return normalized_cdf
200
+
201
+ def _calculate_lookup(self, source_cdf, reference_cdf):
202
+ lookup_table = np.zeros(256)
203
+ lookup_val = 0
204
+ for source_index, source_val in enumerate(source_cdf):
205
+ for reference_index, reference_val in enumerate(reference_cdf):
206
+ if reference_val >= source_val:
207
+ lookup_val = reference_index
208
+ break
209
+ lookup_table[source_index] = lookup_val
210
+ return lookup_table
211
+
212
+ def _match_histograms(self, source, reference, mask):
213
+ transformed_channels = []
214
+ if len(mask.shape) == 3:
215
+ mask = mask[:, :, -1]
216
+
217
+ for channel in range(source.shape[-1]):
218
+ source_channel = source[:, :, channel]
219
+ reference_channel = reference[:, :, channel]
220
+
221
+ # only calculate histograms for non-masked parts
222
+ source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
223
+ reference_histogram, _ = np.histogram(
224
+ reference_channel[mask == 0], 256, [0, 256]
225
+ )
226
+
227
+ source_cdf = self._calculate_cdf(source_histogram)
228
+ reference_cdf = self._calculate_cdf(reference_histogram)
229
+
230
+ lookup = self._calculate_lookup(source_cdf, reference_cdf)
231
+
232
+ transformed_channels.append(cv2.LUT(source_channel, lookup))
233
+
234
+ result = cv2.merge(transformed_channels)
235
+ result = cv2.convertScaleAbs(result)
236
+
237
+ return result
238
+
239
+ def _apply_cropper(self, image, mask, config: InpaintRequest):
240
+ img_h, img_w = image.shape[:2]
241
+ l, t, w, h = (
242
+ config.croper_x,
243
+ config.croper_y,
244
+ config.croper_width,
245
+ config.croper_height,
246
+ )
247
+ r = l + w
248
+ b = t + h
249
+
250
+ l = max(l, 0)
251
+ r = min(r, img_w)
252
+ t = max(t, 0)
253
+ b = min(b, img_h)
254
+
255
+ crop_img = image[t:b, l:r, :]
256
+ crop_mask = mask[t:b, l:r]
257
+ return crop_img, crop_mask, (l, t, r, b)
258
+
259
+ def _run_box(self, image, mask, box, config: InpaintRequest):
260
+ """
261
+
262
+ Args:
263
+ image: [H, W, C] RGB
264
+ mask: [H, W, 1]
265
+ box: [left,top,right,bottom]
266
+
267
+ Returns:
268
+ BGR IMAGE
269
+ """
270
+ crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
271
+
272
+ return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
273
+
274
+
275
+ class DiffusionInpaintModel(InpaintModel):
276
+ def __init__(self, device, **kwargs):
277
+ self.model_info = kwargs["model_info"]
278
+ self.model_id_or_path = self.model_info.path
279
+ super().__init__(device, **kwargs)
280
+
281
+ @torch.no_grad()
282
+ def __call__(self, image, mask, config: InpaintRequest):
283
+ """
284
+ images: [H, W, C] RGB, not normalized
285
+ masks: [H, W]
286
+ return: BGR IMAGE
287
+ """
288
+ # boxes = boxes_from_mask(mask)
289
+ if config.use_croper:
290
+ crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
291
+ crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
292
+ inpaint_result = image[:, :, ::-1]
293
+ inpaint_result[t:b, l:r, :] = crop_image
294
+ elif config.use_extender:
295
+ inpaint_result = self._do_outpainting(image, config)
296
+ else:
297
+ inpaint_result = self._scaled_pad_forward(image, mask, config)
298
+
299
+ return inpaint_result
300
+
301
+ def _do_outpainting(self, image, config: InpaintRequest):
302
+ # cropper 和 image 在同一个坐标系下,croper_x/y 可能为负数
303
+ # 从 image 中 crop 出 outpainting 区域
304
+ image_h, image_w = image.shape[:2]
305
+ cropper_l = config.extender_x
306
+ cropper_t = config.extender_y
307
+ cropper_r = config.extender_x + config.extender_width
308
+ cropper_b = config.extender_y + config.extender_height
309
+ image_l = 0
310
+ image_t = 0
311
+ image_r = image_w
312
+ image_b = image_h
313
+
314
+ # 类似求 IOU
315
+ l = max(cropper_l, image_l)
316
+ t = max(cropper_t, image_t)
317
+ r = min(cropper_r, image_r)
318
+ b = min(cropper_b, image_b)
319
+
320
+ assert (
321
+ 0 <= l < r and 0 <= t < b
322
+ ), f"cropper and image not overlap, {l},{t},{r},{b}"
323
+
324
+ cropped_image = image[t:b, l:r, :]
325
+ padding_l = max(0, image_l - cropper_l)
326
+ padding_t = max(0, image_t - cropper_t)
327
+ padding_r = max(0, cropper_r - image_r)
328
+ padding_b = max(0, cropper_b - image_b)
329
+
330
+ expanded_image, mask_image = expand_image(
331
+ cropped_image,
332
+ left=padding_l,
333
+ top=padding_t,
334
+ right=padding_r,
335
+ bottom=padding_b,
336
+ softness=config.sd_outpainting_softness,
337
+ space=config.sd_outpainting_space,
338
+ )
339
+
340
+ # 最终扩大了的 image, BGR
341
+ expanded_cropped_result_image = self._scaled_pad_forward(
342
+ expanded_image, mask_image, config
343
+ )
344
+
345
+ # RGB -> BGR
346
+ outpainting_image = cv2.copyMakeBorder(
347
+ image,
348
+ left=padding_l,
349
+ top=padding_t,
350
+ right=padding_r,
351
+ bottom=padding_b,
352
+ borderType=cv2.BORDER_CONSTANT,
353
+ value=0,
354
+ )[:, :, ::-1]
355
+
356
+ # 把 cropped_result_image 贴到 outpainting_image 上,这一步不需要 blend
357
+ paste_t = 0 if config.extender_y < 0 else config.extender_y
358
+ paste_l = 0 if config.extender_x < 0 else config.extender_x
359
+
360
+ outpainting_image[
361
+ paste_t : paste_t + expanded_cropped_result_image.shape[0],
362
+ paste_l : paste_l + expanded_cropped_result_image.shape[1],
363
+ :,
364
+ ] = expanded_cropped_result_image
365
+ return outpainting_image
366
+
367
+ def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
368
+ longer_side_length = int(config.sd_scale * max(image.shape[:2]))
369
+ origin_size = image.shape[:2]
370
+ downsize_image = resize_max_size(image, size_limit=longer_side_length)
371
+ downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
372
+ if config.sd_scale != 1:
373
+ logger.info(
374
+ f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
375
+ )
376
+ inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
377
+ # only paste masked area result
378
+ inpaint_result = cv2.resize(
379
+ inpaint_result,
380
+ (origin_size[1], origin_size[0]),
381
+ interpolation=cv2.INTER_CUBIC,
382
+ )
383
+
384
+ # blend result, copy from g_diffuser_bot
385
+ # mask_rgb = 1.0 - np_img_grey_to_rgb(mask / 255.0)
386
+ # inpaint_result = np.clip(
387
+ # inpaint_result * (1.0 - mask_rgb) + image * mask_rgb, 0.0, 255.0
388
+ # )
389
+ # original_pixel_indices = mask < 127
390
+ # inpaint_result[original_pixel_indices] = image[:, :, ::-1][
391
+ # original_pixel_indices
392
+ # ]
393
+ return inpaint_result
394
+
395
+ def set_scheduler(self, config: InpaintRequest):
396
+ scheduler_config = self.model.scheduler.config
397
+ sd_sampler = config.sd_sampler
398
+ if config.sd_lcm_lora and self.model_info.support_lcm_lora:
399
+ sd_sampler = SDSampler.lcm
400
+ logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
401
+ scheduler = get_scheduler(sd_sampler, scheduler_config)
402
+ self.model.scheduler = scheduler
403
+
404
+ def forward_pre_process(self, image, mask, config):
405
+ if config.sd_mask_blur != 0:
406
+ k = 2 * config.sd_mask_blur + 1
407
+ mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
408
+
409
+ return image, mask
410
+
411
+ def forward_post_process(self, result, image, mask, config):
412
+ if config.sd_match_histograms:
413
+ result = self._match_histograms(result, image[:, :, ::-1], mask)
414
+
415
+ if config.sd_mask_blur != 0:
416
+ k = 2 * config.sd_mask_blur + 1
417
+ mask = cv2.GaussianBlur(mask, (k, k), 0)
418
+ return result, image, mask
iopaint/model/helper/__init__.py ADDED
File without changes
iopaint/model/original_sd_configs/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict
3
+
4
+ CURRENT_DIR = Path(__file__).parent.absolute()
5
+
6
+
7
+ def get_config_files() -> Dict[str, Path]:
8
+ """
9
+ - `v1`: Config file for Stable Diffusion v1
10
+ - `v2`: Config file for Stable Diffusion v2
11
+ - `xl`: Config file for Stable Diffusion XL
12
+ - `xl_refiner`: Config file for Stable Diffusion XL Refiner
13
+ """
14
+ return {
15
+ "v1": CURRENT_DIR / "v1-inference.yaml",
16
+ "v2": CURRENT_DIR / "v2-inference-v.yaml",
17
+ "xl": CURRENT_DIR / "sd_xl_base.yaml",
18
+ "xl_refiner": CURRENT_DIR / "sd_xl_refiner.yaml",
19
+ }
iopaint/model/power_paint/__init__.py ADDED
File without changes
iopaint/plugins/__init__.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from loguru import logger
4
+
5
+ from .anime_seg import AnimeSeg
6
+ from .gfpgan_plugin import GFPGANPlugin
7
+ from .interactive_seg import InteractiveSeg
8
+ from .realesrgan import RealESRGANUpscaler
9
+ from .remove_bg import RemoveBG
10
+ from .restoreformer import RestoreFormerPlugin
11
+ from ..schema import InteractiveSegModel, Device, RealESRGANModel
12
+
13
+
14
+ def build_plugins(
15
+ enable_interactive_seg: bool,
16
+ interactive_seg_model: InteractiveSegModel,
17
+ interactive_seg_device: Device,
18
+ enable_remove_bg: bool,
19
+ remove_bg_model: str,
20
+ enable_anime_seg: bool,
21
+ enable_realesrgan: bool,
22
+ realesrgan_device: Device,
23
+ realesrgan_model: RealESRGANModel,
24
+ enable_gfpgan: bool,
25
+ gfpgan_device: Device,
26
+ enable_restoreformer: bool,
27
+ restoreformer_device: Device,
28
+ no_half: bool,
29
+ ) -> Dict:
30
+ plugins = {}
31
+ if enable_interactive_seg:
32
+ logger.info(f"Initialize {InteractiveSeg.name} plugin")
33
+ plugins[InteractiveSeg.name] = InteractiveSeg(
34
+ interactive_seg_model, interactive_seg_device
35
+ )
36
+
37
+ if enable_remove_bg:
38
+ logger.info(f"Initialize {RemoveBG.name} plugin")
39
+ plugins[RemoveBG.name] = RemoveBG(remove_bg_model)
40
+
41
+ if enable_anime_seg:
42
+ logger.info(f"Initialize {AnimeSeg.name} plugin")
43
+ plugins[AnimeSeg.name] = AnimeSeg()
44
+
45
+ if enable_realesrgan:
46
+ logger.info(
47
+ f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
48
+ )
49
+ plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
50
+ realesrgan_model,
51
+ realesrgan_device,
52
+ no_half=no_half,
53
+ )
54
+
55
+ if enable_gfpgan:
56
+ logger.info(f"Initialize {GFPGANPlugin.name} plugin")
57
+ if enable_realesrgan:
58
+ logger.info("Use realesrgan as GFPGAN background upscaler")
59
+ else:
60
+ logger.info(
61
+ f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
62
+ )
63
+ plugins[GFPGANPlugin.name] = GFPGANPlugin(
64
+ gfpgan_device,
65
+ upscaler=plugins.get(RealESRGANUpscaler.name, None),
66
+ )
67
+
68
+ if enable_restoreformer:
69
+ logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
70
+ plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
71
+ restoreformer_device,
72
+ upscaler=plugins.get(RealESRGANUpscaler.name, None),
73
+ )
74
+ return plugins
iopaint/plugins/anime_seg.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ from iopaint.helper import load_model
9
+ from iopaint.plugins.base_plugin import BasePlugin
10
+ from iopaint.schema import RunPluginRequest
11
+
12
+
13
+ class REBNCONV(nn.Module):
14
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
15
+ super(REBNCONV, self).__init__()
16
+
17
+ self.conv_s1 = nn.Conv2d(
18
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
19
+ )
20
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
21
+ self.relu_s1 = nn.ReLU(inplace=True)
22
+
23
+ def forward(self, x):
24
+ hx = x
25
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
26
+
27
+ return xout
28
+
29
+
30
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
31
+ def _upsample_like(src, tar):
32
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
33
+
34
+ return src
35
+
36
+
37
+ ### RSU-7 ###
38
+ class RSU7(nn.Module):
39
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
40
+ super(RSU7, self).__init__()
41
+
42
+ self.in_ch = in_ch
43
+ self.mid_ch = mid_ch
44
+ self.out_ch = out_ch
45
+
46
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
47
+
48
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
49
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
50
+
51
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
52
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
53
+
54
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
55
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
56
+
57
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
58
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
59
+
60
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
61
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
62
+
63
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
64
+
65
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
66
+
67
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
68
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
69
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
70
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
71
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
72
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
73
+
74
+ def forward(self, x):
75
+ b, c, h, w = x.shape
76
+
77
+ hx = x
78
+ hxin = self.rebnconvin(hx)
79
+
80
+ hx1 = self.rebnconv1(hxin)
81
+ hx = self.pool1(hx1)
82
+
83
+ hx2 = self.rebnconv2(hx)
84
+ hx = self.pool2(hx2)
85
+
86
+ hx3 = self.rebnconv3(hx)
87
+ hx = self.pool3(hx3)
88
+
89
+ hx4 = self.rebnconv4(hx)
90
+ hx = self.pool4(hx4)
91
+
92
+ hx5 = self.rebnconv5(hx)
93
+ hx = self.pool5(hx5)
94
+
95
+ hx6 = self.rebnconv6(hx)
96
+
97
+ hx7 = self.rebnconv7(hx6)
98
+
99
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
100
+ hx6dup = _upsample_like(hx6d, hx5)
101
+
102
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
103
+ hx5dup = _upsample_like(hx5d, hx4)
104
+
105
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
106
+ hx4dup = _upsample_like(hx4d, hx3)
107
+
108
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
109
+ hx3dup = _upsample_like(hx3d, hx2)
110
+
111
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
112
+ hx2dup = _upsample_like(hx2d, hx1)
113
+
114
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
115
+
116
+ return hx1d + hxin
117
+
118
+
119
+ ### RSU-6 ###
120
+ class RSU6(nn.Module):
121
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
122
+ super(RSU6, self).__init__()
123
+
124
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
125
+
126
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
127
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
128
+
129
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
130
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
131
+
132
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
133
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
134
+
135
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
136
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
137
+
138
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
139
+
140
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
141
+
142
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
143
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
144
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
145
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
146
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
147
+
148
+ def forward(self, x):
149
+ hx = x
150
+
151
+ hxin = self.rebnconvin(hx)
152
+
153
+ hx1 = self.rebnconv1(hxin)
154
+ hx = self.pool1(hx1)
155
+
156
+ hx2 = self.rebnconv2(hx)
157
+ hx = self.pool2(hx2)
158
+
159
+ hx3 = self.rebnconv3(hx)
160
+ hx = self.pool3(hx3)
161
+
162
+ hx4 = self.rebnconv4(hx)
163
+ hx = self.pool4(hx4)
164
+
165
+ hx5 = self.rebnconv5(hx)
166
+
167
+ hx6 = self.rebnconv6(hx5)
168
+
169
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
170
+ hx5dup = _upsample_like(hx5d, hx4)
171
+
172
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
173
+ hx4dup = _upsample_like(hx4d, hx3)
174
+
175
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
176
+ hx3dup = _upsample_like(hx3d, hx2)
177
+
178
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
179
+ hx2dup = _upsample_like(hx2d, hx1)
180
+
181
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
182
+
183
+ return hx1d + hxin
184
+
185
+
186
+ ### RSU-5 ###
187
+ class RSU5(nn.Module):
188
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
189
+ super(RSU5, self).__init__()
190
+
191
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
192
+
193
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
194
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
195
+
196
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
197
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
198
+
199
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
200
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
201
+
202
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
203
+
204
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
205
+
206
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
207
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
208
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
209
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
210
+
211
+ def forward(self, x):
212
+ hx = x
213
+
214
+ hxin = self.rebnconvin(hx)
215
+
216
+ hx1 = self.rebnconv1(hxin)
217
+ hx = self.pool1(hx1)
218
+
219
+ hx2 = self.rebnconv2(hx)
220
+ hx = self.pool2(hx2)
221
+
222
+ hx3 = self.rebnconv3(hx)
223
+ hx = self.pool3(hx3)
224
+
225
+ hx4 = self.rebnconv4(hx)
226
+
227
+ hx5 = self.rebnconv5(hx4)
228
+
229
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
230
+ hx4dup = _upsample_like(hx4d, hx3)
231
+
232
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
233
+ hx3dup = _upsample_like(hx3d, hx2)
234
+
235
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
236
+ hx2dup = _upsample_like(hx2d, hx1)
237
+
238
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
239
+
240
+ return hx1d + hxin
241
+
242
+
243
+ ### RSU-4 ###
244
+ class RSU4(nn.Module):
245
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
246
+ super(RSU4, self).__init__()
247
+
248
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
249
+
250
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
251
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
252
+
253
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
254
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
255
+
256
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
257
+
258
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
259
+
260
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
261
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
262
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
263
+
264
+ def forward(self, x):
265
+ hx = x
266
+
267
+ hxin = self.rebnconvin(hx)
268
+
269
+ hx1 = self.rebnconv1(hxin)
270
+ hx = self.pool1(hx1)
271
+
272
+ hx2 = self.rebnconv2(hx)
273
+ hx = self.pool2(hx2)
274
+
275
+ hx3 = self.rebnconv3(hx)
276
+
277
+ hx4 = self.rebnconv4(hx3)
278
+
279
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
280
+ hx3dup = _upsample_like(hx3d, hx2)
281
+
282
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
283
+ hx2dup = _upsample_like(hx2d, hx1)
284
+
285
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
286
+
287
+ return hx1d + hxin
288
+
289
+
290
+ ### RSU-4F ###
291
+ class RSU4F(nn.Module):
292
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
293
+ super(RSU4F, self).__init__()
294
+
295
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
296
+
297
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
298
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
299
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
300
+
301
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
302
+
303
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
304
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
305
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
306
+
307
+ def forward(self, x):
308
+ hx = x
309
+
310
+ hxin = self.rebnconvin(hx)
311
+
312
+ hx1 = self.rebnconv1(hxin)
313
+ hx2 = self.rebnconv2(hx1)
314
+ hx3 = self.rebnconv3(hx2)
315
+
316
+ hx4 = self.rebnconv4(hx3)
317
+
318
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
319
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
320
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
321
+
322
+ return hx1d + hxin
323
+
324
+
325
+ class ISNetDIS(nn.Module):
326
+ def __init__(self, in_ch=3, out_ch=1):
327
+ super(ISNetDIS, self).__init__()
328
+
329
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
330
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
331
+
332
+ self.stage1 = RSU7(64, 32, 64)
333
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
334
+
335
+ self.stage2 = RSU6(64, 32, 128)
336
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
337
+
338
+ self.stage3 = RSU5(128, 64, 256)
339
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
340
+
341
+ self.stage4 = RSU4(256, 128, 512)
342
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
343
+
344
+ self.stage5 = RSU4F(512, 256, 512)
345
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
346
+
347
+ self.stage6 = RSU4F(512, 256, 512)
348
+
349
+ # decoder
350
+ self.stage5d = RSU4F(1024, 256, 512)
351
+ self.stage4d = RSU4(1024, 128, 256)
352
+ self.stage3d = RSU5(512, 64, 128)
353
+ self.stage2d = RSU6(256, 32, 64)
354
+ self.stage1d = RSU7(128, 16, 64)
355
+
356
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
357
+
358
+ def forward(self, x):
359
+ hx = x
360
+
361
+ hxin = self.conv_in(hx)
362
+ hx = self.pool_in(hxin)
363
+
364
+ # stage 1
365
+ hx1 = self.stage1(hxin)
366
+ hx = self.pool12(hx1)
367
+
368
+ # stage 2
369
+ hx2 = self.stage2(hx)
370
+ hx = self.pool23(hx2)
371
+
372
+ # stage 3
373
+ hx3 = self.stage3(hx)
374
+ hx = self.pool34(hx3)
375
+
376
+ # stage 4
377
+ hx4 = self.stage4(hx)
378
+ hx = self.pool45(hx4)
379
+
380
+ # stage 5
381
+ hx5 = self.stage5(hx)
382
+ hx = self.pool56(hx5)
383
+
384
+ # stage 6
385
+ hx6 = self.stage6(hx)
386
+ hx6up = _upsample_like(hx6, hx5)
387
+
388
+ # -------------------- decoder --------------------
389
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
390
+ hx5dup = _upsample_like(hx5d, hx4)
391
+
392
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
393
+ hx4dup = _upsample_like(hx4d, hx3)
394
+
395
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
396
+ hx3dup = _upsample_like(hx3d, hx2)
397
+
398
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
399
+ hx2dup = _upsample_like(hx2d, hx1)
400
+
401
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
402
+
403
+ # side output
404
+ d1 = self.side1(hx1d)
405
+ d1 = _upsample_like(d1, x)
406
+ return d1.sigmoid()
407
+
408
+
409
+ # 从小到大
410
+ ANIME_SEG_MODELS = {
411
+ "url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth",
412
+ "md5": "5f25479076b73074730ab8de9e8f2051",
413
+ }
414
+
415
+
416
+ class AnimeSeg(BasePlugin):
417
+ # Model from: https://github.com/SkyTNT/anime-segmentation
418
+ name = "AnimeSeg"
419
+ support_gen_image = True
420
+ support_gen_mask = True
421
+
422
+ def __init__(self):
423
+ super().__init__()
424
+ self.model = load_model(
425
+ ISNetDIS(),
426
+ ANIME_SEG_MODELS["url"],
427
+ "cpu",
428
+ ANIME_SEG_MODELS["md5"],
429
+ )
430
+
431
+ def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
432
+ mask = self.forward(rgb_np_img)
433
+ mask = Image.fromarray(mask, mode="L")
434
+ h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
435
+ empty = Image.new("RGBA", (w0, h0), 0)
436
+ img = Image.fromarray(rgb_np_img)
437
+ cutout = Image.composite(img, empty, mask)
438
+ return np.asarray(cutout)
439
+
440
+ def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
441
+ return self.forward(rgb_np_img)
442
+
443
+ @torch.inference_mode()
444
+ def forward(self, rgb_np_img):
445
+ s = 1024
446
+
447
+ h0, w0 = h, w = rgb_np_img.shape[0], rgb_np_img.shape[1]
448
+ if h > w:
449
+ h, w = s, int(s * w / h)
450
+ else:
451
+ h, w = int(s * h / w), s
452
+ ph, pw = s - h, s - w
453
+ tmpImg = np.zeros([s, s, 3], dtype=np.float32)
454
+ tmpImg[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = (
455
+ cv2.resize(rgb_np_img, (w, h)) / 255
456
+ )
457
+ tmpImg = tmpImg.transpose((2, 0, 1))
458
+ tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor)
459
+ mask = self.model(tmpImg)
460
+ mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
461
+ mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
462
+ return (mask * 255).astype("uint8")
iopaint/plugins/base_plugin.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from loguru import logger
2
+ import numpy as np
3
+
4
+ from iopaint.schema import RunPluginRequest
5
+
6
+
7
+ class BasePlugin:
8
+ name: str
9
+ support_gen_image: bool = False
10
+ support_gen_mask: bool = False
11
+
12
+ def __init__(self):
13
+ err_msg = self.check_dep()
14
+ if err_msg:
15
+ logger.error(err_msg)
16
+ exit(-1)
17
+
18
+ def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
19
+ # return RGBA np image or BGR np image
20
+ ...
21
+
22
+ def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
23
+ # return GRAY or BGR np image, 255 means foreground, 0 means background
24
+ ...
25
+
26
+ def check_dep(self):
27
+ ...
28
+
29
+ def switch_model(self, new_model_name: str):
30
+ ...
iopaint/plugins/segment_anything/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .build_sam import (
8
+ build_sam,
9
+ build_sam_vit_h,
10
+ build_sam_vit_l,
11
+ build_sam_vit_b,
12
+ sam_model_registry,
13
+ )
14
+ from .predictor import SamPredictor
iopaint/plugins/segment_anything/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .sam import Sam
8
+ from .image_encoder import ImageEncoderViT
9
+ from .mask_decoder import MaskDecoder
10
+ from .prompt_encoder import PromptEncoder
11
+ from .transformer import TwoWayTransformer
iopaint/plugins/segment_anything/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
iopaint/tests/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *_result.png
2
+ result/
iopaint/tests/__init__.py ADDED
File without changes
model/__init__.py ADDED
File without changes
utils/__init__.py ADDED
File without changes