Alexander Becker commited on
Commit
b139995
·
1 Parent(s): 4fd4a46
Files changed (46) hide show
  1. .gitattributes +1 -0
  2. .gitignore +174 -0
  3. app.py +201 -4
  4. checkpoints/thera-edsr-plus.pkl +3 -0
  5. files/1_skyscrapers_2.png +3 -0
  6. files/2_manga1.png +3 -0
  7. files/3_bird.png +3 -0
  8. files/koala.png +3 -0
  9. files/manga3.png +3 -0
  10. files/raw/0853.png +3 -0
  11. files/raw/0853C.png +3 -0
  12. files/raw/69015.png +3 -0
  13. files/raw/GakuenNoise.png +3 -0
  14. files/raw/GakuenNoiseC.png +3 -0
  15. files/raw/UchiNoNyansDiary_000.png +3 -0
  16. files/raw/UchiNoNyansDiary_000_C.png +3 -0
  17. files/zebra_8.png +3 -0
  18. gradio_dualvision/__init__.py +26 -0
  19. gradio_dualvision/app_template.py +614 -0
  20. gradio_dualvision/gradio_patches/__init__.py +0 -0
  21. gradio_dualvision/gradio_patches/examples.py +36 -0
  22. gradio_dualvision/gradio_patches/gallery.py +77 -0
  23. gradio_dualvision/gradio_patches/gallery.pyi +82 -0
  24. gradio_dualvision/gradio_patches/imagesliderplus.py +156 -0
  25. gradio_dualvision/gradio_patches/imagesliderplus.pyi +161 -0
  26. gradio_dualvision/gradio_patches/radio.py +62 -0
  27. gradio_dualvision/gradio_patches/radio.pyi +67 -0
  28. gradio_dualvision/gradio_patches/templates/component/__vite-browser-external-2447137e.js +4 -0
  29. gradio_dualvision/gradio_patches/templates/component/index.js +0 -0
  30. gradio_dualvision/gradio_patches/templates/component/style.css +1 -0
  31. gradio_dualvision/gradio_patches/templates/component/wrapper-6f348d45-19fa94bf.js +2453 -0
  32. gradio_dualvision/gradio_patches/templates/example/index.js +95 -0
  33. gradio_dualvision/gradio_patches/templates/example/style.css +1 -0
  34. gradio_dualvision/version.py +25 -0
  35. model/__init__.py +2 -0
  36. model/convnext.py +55 -0
  37. model/edsr.py +122 -0
  38. model/hyper.py +41 -0
  39. model/init.py +24 -0
  40. model/rdn.py +72 -0
  41. model/swin_ir.py +532 -0
  42. model/tail.py +18 -0
  43. model/thera.py +175 -0
  44. requirements.txt +40 -0
  45. super_resolve.py +99 -0
  46. utils.py +36 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
app.py CHANGED
@@ -1,8 +1,205 @@
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import json
3
+ import os
4
+
5
  import gradio as gr
6
+ from gradio_dualvision import DualVisionApp
7
+ from gradio_dualvision.gradio_patches.radio import Radio
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+ from model import build_thera
12
+ from super_resolve import process
13
+
14
+ CHECKPOINT = "checkpoints/thera-edsr-plus.pkl"
15
+
16
+
17
+ class TheraApp(DualVisionApp):
18
+ DEFAULT_SCALE = 3.1415
19
+ DEFAULT_DO_ENSEMBLE = False
20
+
21
+ def make_header(self):
22
+ gr.Markdown(
23
+ """
24
+ ## Thera: Aliasing-Free Arbitrary-Scale Super-Resolution with Neural Heat Fields
25
+ <p align="center">
26
+ <a title="Website" href="https://therasr.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
27
+ <img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
28
+ </a>
29
+ <a title="arXiv" href="https://arxiv.org/pdf/2311.17643" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
30
+ <img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-AF3436">
31
+ </a>
32
+ <a title="Github" href="https://github.com/prs-eth/thera" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
33
+ <img src="https://img.shields.io/github/stars/prs-eth/thera?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
34
+ </a>
35
+ </p>
36
+ <p align="center" style="margin-top: 0px;">
37
+ Upload a photo or select an example below to do arbitrary-scale super-resolution in real time!
38
+ </p>
39
+ """
40
+ )
41
+
42
+ def build_user_components(self):
43
+ with gr.Column():
44
+ scale = gr.Slider(
45
+ label="Scaling factor",
46
+ minimum=1,
47
+ maximum=6,
48
+ step=0.01,
49
+ value=self.DEFAULT_SCALE,
50
+ )
51
+ do_ensemble = gr.Radio(
52
+ [
53
+ ("No", False),
54
+ ("Yes", True),
55
+ ],
56
+ label="Do Ensemble",
57
+ value=self.DEFAULT_DO_ENSEMBLE,
58
+ )
59
+ return {
60
+ "scale": scale,
61
+ "do_ensemble": do_ensemble,
62
+ }
63
+
64
+ def process(self, image_in: Image.Image, **kwargs):
65
+ scale = kwargs.get("scale", self.DEFAULT_SCALE)
66
+ do_ensemble = kwargs.get("do_ensemble", self.DEFAULT_DO_ENSEMBLE)
67
+
68
+ source = np.asarray(image_in) / 255.
69
+
70
+ # determine target shape
71
+ target_shape = (
72
+ round(source.shape[0] * scale),
73
+ round(source.shape[1] * scale),
74
+ )
75
+
76
+ # load model
77
+ with open(CHECKPOINT, 'rb') as fh:
78
+ check = pickle.load(fh)
79
+ params, backbone, size = check['model'], check['backbone'], check['size']
80
+
81
+ model = build_thera(3, backbone, size)
82
+
83
+ out = process(source, model, params, target_shape, do_ensemble=do_ensemble)
84
+ out = Image.fromarray(np.asarray(out))
85
+
86
+ nearest = image_in.resize(out.size, Image.NEAREST)
87
+
88
+ out_modalities = {
89
+ "nearest": nearest,
90
+ "out": out,
91
+ }
92
+ out_settings = {
93
+ 'scale': scale,
94
+ 'do_ensemble': do_ensemble,
95
+ }
96
+ return out_modalities, out_settings
97
+
98
+ def process_components(
99
+ self, image_in, modality_selector_left, modality_selector_right, **kwargs
100
+ ):
101
+ if image_in is None:
102
+ raise gr.Error("Input image is required")
103
+
104
+ image_settings = {}
105
+ if isinstance(image_in, str):
106
+ image_settings_path = image_in + ".settings.json"
107
+ if os.path.isfile(image_settings_path):
108
+ with open(image_settings_path, "r") as f:
109
+ image_settings = json.load(f)
110
+ image_in = Image.open(image_in).convert("RGB")
111
+ else:
112
+ if not isinstance(image_in, Image.Image):
113
+ raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
114
+ image_in = image_in.convert("RGB")
115
+ image_settings.update(kwargs)
116
+
117
+ results_dict, results_settings = self.process(image_in, **image_settings)
118
+
119
+ if not isinstance(results_dict, dict):
120
+ raise gr.Error(
121
+ f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}"
122
+ )
123
+ if len(results_dict) == 0:
124
+ raise gr.Error("`process` did not return any modalities")
125
+ for k, v in results_dict.items():
126
+ if not isinstance(k, str):
127
+ raise gr.Error(
128
+ f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}"
129
+ )
130
+ if k == self.key_original_image:
131
+ raise gr.Error(
132
+ f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
133
+ )
134
+ if not isinstance(v, Image.Image):
135
+ raise gr.Error(
136
+ f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
137
+ )
138
+ if len(results_settings) != len(self.input_keys):
139
+ raise gr.Error(
140
+ f"Expected number of settings ({len(self.input_keys)}), returned ({len(results_settings)})"
141
+ )
142
+ if any(k not in results_settings for k in self.input_keys):
143
+ raise gr.Error(f"Mismatching setgings keys")
144
+ results_settings = {
145
+ k: cls(**ctor_args, value=results_settings[k])
146
+ for k, cls, ctor_args in zip(
147
+ self.input_keys, self.input_cls, self.input_kwargs
148
+ )
149
+ }
150
+
151
+ results_dict = {
152
+ **results_dict,
153
+ self.key_original_image: image_in,
154
+ }
155
+
156
+ results_state = [[v, k] for k, v in results_dict.items()]
157
+ modalities = list(results_dict.keys())
158
+
159
+ modality_left = (
160
+ modality_selector_left
161
+ if modality_selector_left in modalities
162
+ else modalities[0]
163
+ )
164
+ modality_right = (
165
+ modality_selector_right
166
+ if modality_selector_right in modalities
167
+ else modalities[1]
168
+ )
169
 
170
+ return [
171
+ results_state, # goes to a gr.Gallery
172
+ [
173
+ results_dict[modality_left],
174
+ results_dict[modality_right],
175
+ ], # ImageSliderPlus
176
+ Radio(
177
+ choices=modalities,
178
+ value=modality_left,
179
+ label="Left",
180
+ key="Left",
181
+ ),
182
+ Radio(
183
+ choices=modalities if self.left_selector_visible else modalities[1:],
184
+ value=modality_right,
185
+ label="Right",
186
+ key="Right",
187
+ ),
188
+ *results_settings.values(),
189
+ ]
190
 
 
 
191
 
192
+ with TheraApp(
193
+ title="Thera Arbitrary-Scale Super-Resolution",
194
+ examples_path="files",
195
+ examples_per_page=12,
196
+ squeeze_canvas=True,
197
+ advanced_settings_can_be_half_width=False,
198
+ #spaces_zero_gpu_enabled=True,
199
+ ) as demo:
200
+ demo.queue(
201
+ api_open=False,
202
+ ).launch(
203
+ server_name="0.0.0.0",
204
+ server_port=7860,
205
+ )
checkpoints/thera-edsr-plus.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a805ca6f0486d9eba8f228200340a0e6aedde16529e11fc7b98dc26d830d9aa8
3
+ size 31632862
files/1_skyscrapers_2.png ADDED

Git LFS Details

  • SHA256: fe1041d02b3decdd1eaa91c56343d561344090d21221ace4df5a62df1d4398c5
  • Pointer size: 130 Bytes
  • Size of remote file: 36.8 kB
files/2_manga1.png ADDED

Git LFS Details

  • SHA256: f5ce1cedae4b4068a41bcf35861b85bad2eaa0d7caf6a32974a0ab28d65a8952
  • Pointer size: 130 Bytes
  • Size of remote file: 41.8 kB
files/3_bird.png ADDED

Git LFS Details

  • SHA256: 14f8ec3b6b774de83cc503ebe8bfc80871d8e96a96ac5c9e1b6c4e8c04ad343b
  • Pointer size: 130 Bytes
  • Size of remote file: 56.6 kB
files/koala.png ADDED

Git LFS Details

  • SHA256: 26d15440ca6f9107e29da5659abd4b16a95cc2309596f18d4646170bad922b8f
  • Pointer size: 130 Bytes
  • Size of remote file: 20 kB
files/manga3.png ADDED

Git LFS Details

  • SHA256: df3f4abd69abdd139f7659350b690cf8bafc3db9cc4d8238d1784171f732f0f6
  • Pointer size: 130 Bytes
  • Size of remote file: 23.1 kB
files/raw/0853.png ADDED

Git LFS Details

  • SHA256: 854189ea5c2a325324500c6cf7168e0b78d30b09606eb5a59f938a5ae69139fe
  • Pointer size: 132 Bytes
  • Size of remote file: 3.7 MB
files/raw/0853C.png ADDED

Git LFS Details

  • SHA256: 2aa8dfe9a681bd4d8e04e4d16ebaaee79bcb691ff4868dd1746c8f6920aa7f0b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
files/raw/69015.png ADDED

Git LFS Details

  • SHA256: f1efe9c87a9b26ab506957d421a3e5e0c0422681b8cff4d48395008d21032ac0
  • Pointer size: 131 Bytes
  • Size of remote file: 256 kB
files/raw/GakuenNoise.png ADDED

Git LFS Details

  • SHA256: 0a4d3e304ac9bc47096cc46c24e8cb87668a8bd189bf747e19d30fee8f85d9a8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.44 MB
files/raw/GakuenNoiseC.png ADDED

Git LFS Details

  • SHA256: 4cbd9ee9ad75641ed4ca4e8f8cc5f439243b0011034926c4eec667567438c0b1
  • Pointer size: 131 Bytes
  • Size of remote file: 275 kB
files/raw/UchiNoNyansDiary_000.png ADDED

Git LFS Details

  • SHA256: 955eab4bfccdc1840b966522ea1f4e701ee283ead77d2442c9afd426a0223cea
  • Pointer size: 131 Bytes
  • Size of remote file: 879 kB
files/raw/UchiNoNyansDiary_000_C.png ADDED

Git LFS Details

  • SHA256: f4b3ccb5d8124e80dc0ea8ad68776cde708d4d84822aef6e61702de0f0a771ba
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
files/zebra_8.png ADDED

Git LFS Details

  • SHA256: 3e1a57094330b4012795ff9009a2eb48f170a48c151d7bcee2a0a4eb786e0f03
  • Pointer size: 130 Bytes
  • Size of remote file: 10.7 kB
gradio_dualvision/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+
25
+ from .version import __version__
26
+ from .app_template import DualVisionApp
gradio_dualvision/app_template.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+ import glob
25
+ import json
26
+ import os
27
+ import re
28
+
29
+ import gradio as gr
30
+ import spaces
31
+ from PIL import Image
32
+ from gradio.components.base import Component
33
+
34
+ from .gradio_patches.examples import Examples
35
+ from .gradio_patches.gallery import Gallery
36
+ from .gradio_patches.imagesliderplus import ImageSliderPlus
37
+ from .gradio_patches.radio import Radio
38
+
39
+
40
+ class DualVisionApp(gr.Blocks):
41
+ def __init__(
42
+ self,
43
+ title,
44
+ examples_path="examples",
45
+ examples_per_page=12,
46
+ examples_cache="lazy",
47
+ squeeze_canvas=True,
48
+ squeeze_viewport_height_pct=75,
49
+ left_selector_visible=False,
50
+ advanced_settings_can_be_half_width=True,
51
+ key_original_image="Original",
52
+ spaces_zero_gpu_enabled=False,
53
+ spaces_zero_gpu_duration=None,
54
+ slider_position=0.5,
55
+ slider_line_color="#FFF",
56
+ slider_line_width="4px",
57
+ slider_arrows_color="#FFF",
58
+ slider_arrows_width="2px",
59
+ gallery_thumb_min_size="96px",
60
+ **kwargs,
61
+ ):
62
+ """
63
+ A wrapper around Gradio Blocks class that implements an image processing demo template. All the user has to do
64
+ is to subclass this class and implement two methods: `process` implementing the image processing, and
65
+ `build_user_components` implementing Gradio components reading the additional processing arguments.
66
+ Args:
67
+ title: Title of the application (str, required).
68
+ examples_path: Base path where examples will be searched (Default: `"examples"`).
69
+ examples_per_page: How many examples to show at the bottom of the app (Default: `12`).
70
+ examples_cache: Examples caching policy, corresponding to `cache_examples` argument of gradio.Examples (Default: `"lazy"`).
71
+ squeeze_canvas: When True, the image is fit to the browser viewport. When False, the image is fit to width (Default: `True`).
72
+ squeeze_viewport_height_pct: Percentage of the browser viewport height (Default: `75`).
73
+ left_selector_visible: Whether controls for changing modalities in the left part of the slider are visible (Default: `False`).
74
+ key_original_image: Name of the key under which the input image is shown in the modality selectors (Default: `"Original"`).
75
+ advanced_settings_can_be_half_width: Whether allow placing advanced settings dropdown in half-column space whenever possible (Default: `True`).
76
+ spaces_zero_gpu_enabled: When True, the app wraps the processing function with the ZeroGPU decorator.
77
+ spaces_zero_gpu_duration: Defines an integer duration in seconds passed into the ZeroGPU decorator.
78
+ slider_position: Position of the slider between 0 and 1 (Default: `0.5`).
79
+ slider_line_color: Color of the slider line (Default: `"#FFF"`).
80
+ slider_line_width: Width of the slider line (Default: `"4px"`).
81
+ slider_arrows_color: Color of the slider arrows (Default: `"#FFF"`).
82
+ slider_arrows_width: Width of the slider arrows (Default: `2px`).
83
+ gallery_thumb_min_size: Min size of the gallery thumbnail (Default: `96px`).
84
+ **kwargs: Any other arguments that Gradio Blocks class can take.
85
+ """
86
+ squeeze_viewport_height_pct = int(squeeze_viewport_height_pct)
87
+ if not 50 <= squeeze_viewport_height_pct <= 100:
88
+ raise gr.Error(
89
+ "`squeeze_viewport_height_pct` should be an integer between 50 and 100."
90
+ )
91
+ if not os.path.isdir(examples_path):
92
+ raise gr.Error("`examples_path` should be a directory.")
93
+ if not 0 <= slider_position <= 1:
94
+ raise gr.Error("`slider_position` should be between 0 and 1.")
95
+ kwargs = {k: v for k, v in kwargs.items()}
96
+ kwargs["title"] = title
97
+ self.examples_path = examples_path
98
+ self.examples_per_page = examples_per_page
99
+ self.examples_cache = examples_cache
100
+ self.key_original_image = key_original_image
101
+ self.slider_position = slider_position
102
+ self.input_keys = None
103
+ self.input_cls = None
104
+ self.input_kwargs = None
105
+ self.left_selector_visible = left_selector_visible
106
+ self.advanced_settings_can_be_half_width = advanced_settings_can_be_half_width
107
+ if spaces_zero_gpu_enabled:
108
+ self.process_components = spaces.GPU(
109
+ self.process_components, duration=spaces_zero_gpu_duration
110
+ )
111
+ self.head = ""
112
+ self.head += """
113
+ <script>
114
+ let observerFooterButtons = new MutationObserver((mutationsList, observer) => {
115
+ const oldButtonLeft = document.querySelector(".show-api");
116
+ const oldButtonRight = document.querySelector(".built-with");
117
+ if (!oldButtonRight || !oldButtonLeft) {
118
+ return;
119
+ }
120
+ observer.disconnect();
121
+
122
+ const parentDiv = oldButtonLeft.parentNode;
123
+ if (!parentDiv) return;
124
+
125
+ const createButton = (referenceButton, text, href) => {
126
+ let newButton = referenceButton.cloneNode(true);
127
+ newButton.href = href;
128
+ newButton.textContent = text;
129
+ newButton.className = referenceButton.className;
130
+ newButton.style.textDecoration = "none";
131
+ newButton.style.display = "inline-block";
132
+ newButton.style.cursor = "pointer";
133
+ return newButton;
134
+ };
135
+
136
+ const newButton0 = createButton(oldButtonRight, "Built with Gradio DualVision", "https://github.com/toshas/gradio-dualvision");
137
+ const newButton1 = createButton(oldButtonRight, "Template by Anton Obukhov", "https://www.obukhov.ai");
138
+ const newButton2 = createButton(oldButtonRight, "Licensed under CC BY-SA 4.0", "http://creativecommons.org/licenses/by-sa/4.0/");
139
+
140
+ const separatorDiv = document.createElement("div");
141
+ separatorDiv.className = "svelte-1rjryqp";
142
+ separatorDiv.textContent = "·";
143
+
144
+ parentDiv.replaceChild(newButton0, oldButtonLeft);
145
+ parentDiv.replaceChild(newButton1, oldButtonRight);
146
+ parentDiv.appendChild(separatorDiv);
147
+ parentDiv.appendChild(newButton2);
148
+ });
149
+ observerFooterButtons.observe(document.body, { childList: true, subtree: true });
150
+ </script>
151
+ """
152
+ if kwargs.get("analytics_enabled") is not False:
153
+ self.head += f"""
154
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
155
+ <script>
156
+ window.dataLayer = window.dataLayer || [];
157
+ function gtag() {{dataLayer.push(arguments);}}
158
+ gtag('js', new Date());
159
+ gtag('config', 'G-1FWSVCGZTG');
160
+ </script>
161
+ """
162
+ self.css = f"""
163
+ body {{ /* tighten the layout */
164
+ flex-grow: 0 !important;
165
+ }}
166
+ .sliderrow {{ /* center the slider */
167
+ display: flex;
168
+ justify-content: center;
169
+ }}
170
+ .slider {{ /* center the slider */
171
+ display: flex;
172
+ justify-content: center;
173
+ width: 100%;
174
+ }}
175
+ .slider .disabled {{ /* hide the main slider before image load */
176
+ visibility: hidden;
177
+ }}
178
+ .slider .svelte-9gxdi0 {{ /* hide the component label in the top-left corner before image load */
179
+ visibility: hidden;
180
+ }}
181
+ .slider .svelte-kzcjhc .icon-wrap {{
182
+ height: 0px; /* remove unnecessary spaces in captions before image load */
183
+ }}
184
+ .slider .svelte-kzcjhc.wrap {{
185
+ padding-top: 0px; /* remove unnecessary spaces in captions before image load */
186
+ }}
187
+ .slider .svelte-3w3rth {{ /* hide the dummy icon from the right pane before image load */
188
+ visibility: hidden;
189
+ }}
190
+ .svelte-106mu0e.icon-buttons {{ /* download button */
191
+ /*right: unset;*/
192
+ /*left: 8px;*/
193
+ }}
194
+ .slider .fixed {{ /* fix the opacity of the right pane image */
195
+ background-color: var(--anim-block-background-fill);
196
+ }}
197
+ .slider .inner {{ /* style slider line */
198
+ width: {slider_line_width};
199
+ background: {slider_line_color};
200
+ }}
201
+ .slider .icon-wrap svg {{ /* style slider arrows */
202
+ stroke: {slider_arrows_color};
203
+ stroke-width: {slider_arrows_width};
204
+ }}
205
+ .slider .icon-wrap path {{ /* style slider arrows */
206
+ fill: {slider_arrows_color};
207
+ }}
208
+ .row_reverse {{
209
+ flex-direction: row-reverse;
210
+ }}
211
+ .gallery.svelte-l4wpk0 {{ /* make examples gallery tiles square */
212
+ width: max({gallery_thumb_min_size}, calc(100vw / 8));
213
+ height: max({gallery_thumb_min_size}, calc(100vw / 8));
214
+ }}
215
+ .gallery.svelte-l4wpk0 img {{ /* make examples gallery tiles square */
216
+ width: max({gallery_thumb_min_size}, calc(100vw / 8));
217
+ height: max({gallery_thumb_min_size}, calc(100vw / 8));
218
+ }}
219
+ .gallery.svelte-l4wpk0 img {{ /* remove slider line from previews */
220
+ clip-path: inset(0 0 0 0);
221
+ }}
222
+ .gallery.svelte-l4wpk0 span {{ /* remove slider line from previews */
223
+ visibility: hidden;
224
+ }}
225
+ h1, h2, h3 {{ /* center markdown headings */
226
+ text-align: center;
227
+ display: block;
228
+ }}
229
+ #settings-accordion {{
230
+ margin: 0 auto;
231
+ max-width: 500px;
232
+ }}
233
+ """
234
+ if squeeze_canvas:
235
+ self.head += f"""
236
+ <script>
237
+ // fixes vertical size of the component when used inside of iframeResizer (on spaces)
238
+ function squeezeViewport() {{
239
+ if (typeof window.parentIFrame === "undefined") return;
240
+ const images = document.querySelectorAll('.slider img');
241
+ window.parentIFrame.getPageInfo((info) => {{
242
+ images.forEach((img) => {{
243
+ const imgMaxHeightNew = (info.clientHeight * {squeeze_viewport_height_pct}) / 100;
244
+ img.style.maxHeight = `${{imgMaxHeightNew}}px`;
245
+ // window.parentIFrame.size(0, null); // tighten the layout; body's flex-grow: 0 is less intrusive
246
+ }});
247
+ }});
248
+ }}
249
+ window.addEventListener('resize', squeezeViewport);
250
+
251
+ // fixes gradio-imageslider wrong position behavior when using fitting to content by triggering resize
252
+ let observer = new MutationObserver((mutationsList) => {{
253
+ const images = document.querySelectorAll('.slider img');
254
+ images.forEach((img) => {{
255
+ if (img.complete) {{
256
+ window.dispatchEvent(new Event('resize'));
257
+ }} else {{
258
+ img.onload = () => {{
259
+ window.dispatchEvent(new Event('resize'));
260
+ }}
261
+ }}
262
+ }});
263
+ }});
264
+ observer.observe(document.body, {{ childList: true, subtree: true }});
265
+ </script>
266
+ """
267
+ self.css += f"""
268
+ .slider {{ /* make the slider dimensions fit to the uploaded content dimensions */
269
+ max-width: fit-content;
270
+ }}
271
+ .slider .half-wrap {{ /* make the empty component width almost full before image load */
272
+ width: 70vw;
273
+ }}
274
+ .slider img {{ /* Ensures image fits inside the viewport */
275
+ max-height: {squeeze_viewport_height_pct}vh;
276
+ }}
277
+ """
278
+ else:
279
+ self.css += f"""
280
+ .slider .half-wrap {{ /* make the upload area full width */
281
+ width: 100%;
282
+ }}
283
+ """
284
+ kwargs["css"] = kwargs.get("css", "") + self.css
285
+ kwargs["head"] = kwargs.get("head", "") + self.head
286
+ super().__init__(**kwargs)
287
+ with self:
288
+ self.make_interface()
289
+
290
+ def process(self, image_in: Image.Image, **kwargs):
291
+ """
292
+ Process an input image into multiple modalities using the provided arguments or default settings.
293
+ Returns two dictionaries: one containing the modalities and another with the actual settings.
294
+ Override this method in a subclass.
295
+ """
296
+ raise NotImplementedError("Please override the `process` method.")
297
+
298
+ def build_user_components(self):
299
+ """
300
+ Create gradio components for the Advanced Settings dropdown, that will be passed into the `process` method.
301
+ Use gr.Row(), gr.Column(), and other context managers to arrange the components. Return them as a flat dict.
302
+ Override this method in a subclass.
303
+ """
304
+ raise NotImplementedError("Please override the `build_user_components` method.")
305
+
306
+ def discover_examples(self):
307
+ """
308
+ Looks for valid image filenames.
309
+ """
310
+ pattern = re.compile(r".*\.(jpg|JPG|jpeg|JPEG|png|PNG)$")
311
+ paths = glob.glob(f"{self.examples_path}/*")
312
+ out = list(sorted(filter(pattern.match, paths)))
313
+ return out
314
+
315
+ def process_components(
316
+ self, image_in, modality_selector_left, modality_selector_right, **kwargs
317
+ ):
318
+ """
319
+ Wraps the call to `process`. Returns results in a structure used by the gallery, slider, radio components.
320
+ """
321
+ if image_in is None:
322
+ raise gr.Error("Input image is required")
323
+
324
+ image_settings = {}
325
+ if isinstance(image_in, str):
326
+ image_settings_path = image_in + ".settings.json"
327
+ if os.path.isfile(image_settings_path):
328
+ with open(image_settings_path, "r") as f:
329
+ image_settings = json.load(f)
330
+ image_in = Image.open(image_in).convert("RGB")
331
+ else:
332
+ if not isinstance(image_in, Image.Image):
333
+ raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
334
+ image_in = image_in.convert("RGB")
335
+ image_settings.update(kwargs)
336
+
337
+ results_dict, results_settings = self.process(image_in, **image_settings)
338
+
339
+ if not isinstance(results_dict, dict):
340
+ raise gr.Error(
341
+ f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}"
342
+ )
343
+ if len(results_dict) == 0:
344
+ raise gr.Error("`process` did not return any modalities")
345
+ for k, v in results_dict.items():
346
+ if not isinstance(k, str):
347
+ raise gr.Error(
348
+ f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}"
349
+ )
350
+ if k == self.key_original_image:
351
+ raise gr.Error(
352
+ f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
353
+ )
354
+ if not isinstance(v, Image.Image):
355
+ raise gr.Error(
356
+ f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
357
+ )
358
+ if len(results_settings) != len(self.input_keys):
359
+ raise gr.Error(
360
+ f"Expected number of settings ({len(self.input_keys)}), returned ({len(results_settings)})"
361
+ )
362
+ if any(k not in results_settings for k in self.input_keys):
363
+ raise gr.Error(f"Mismatching setgings keys")
364
+ results_settings = {
365
+ k: cls(**ctor_args, value=results_settings[k])
366
+ for k, cls, ctor_args in zip(
367
+ self.input_keys, self.input_cls, self.input_kwargs
368
+ )
369
+ }
370
+
371
+ results_dict = {
372
+ self.key_original_image: image_in,
373
+ **results_dict,
374
+ }
375
+
376
+ results_state = [[v, k] for k, v in results_dict.items()]
377
+ modalities = list(results_dict.keys())
378
+
379
+ modality_left = (
380
+ modality_selector_left
381
+ if modality_selector_left in modalities
382
+ else modalities[0]
383
+ )
384
+ modality_right = (
385
+ modality_selector_right
386
+ if modality_selector_right in modalities
387
+ else modalities[1]
388
+ )
389
+
390
+ return [
391
+ results_state, # goes to a gr.Gallery
392
+ [
393
+ results_dict[modality_left],
394
+ results_dict[modality_right],
395
+ ], # ImageSliderPlus
396
+ Radio(
397
+ choices=modalities,
398
+ value=modality_left,
399
+ label="Left",
400
+ key="Left",
401
+ ),
402
+ Radio(
403
+ choices=modalities if self.left_selector_visible else modalities[1:],
404
+ value=modality_right,
405
+ label="Right",
406
+ key="Right",
407
+ ),
408
+ *results_settings.values(),
409
+ ]
410
+
411
+ def on_process_first(
412
+ self,
413
+ image_slider,
414
+ modality_selector_left=None,
415
+ modality_selector_right=None,
416
+ *args,
417
+ ):
418
+ image_in = image_slider[0]
419
+ input_dict = {}
420
+ if len(args) > 0:
421
+ input_dict = {k: v for k, v in zip(self.input_keys, args)}
422
+ return self.process_components(
423
+ image_in, modality_selector_left, modality_selector_right, **input_dict
424
+ )
425
+
426
+ def on_process_subsequent(
427
+ self, results_state, modality_selector_left, modality_selector_right, *args
428
+ ):
429
+ if results_state is None:
430
+ raise gr.Error("Upload an image first or use an example below.")
431
+ results_state = {k: v for v, k in results_state}
432
+ image_in = results_state[self.key_original_image]
433
+ input_dict = {k: v for k, v in zip(self.input_keys, args)}
434
+ return self.process_components(
435
+ image_in, modality_selector_left, modality_selector_right, **input_dict
436
+ )
437
+
438
+ def on_selector_change_left(
439
+ self, results_state, image_slider, modality_selector_left
440
+ ):
441
+ results_state = {k: v for v, k in results_state}
442
+ return [results_state[modality_selector_left], image_slider[1]]
443
+
444
+ def on_selector_change_right(
445
+ self, results_state, image_slider, modality_selector_right
446
+ ):
447
+ results_state = {k: v for v, k in results_state}
448
+ return [image_slider[0], results_state[modality_selector_right]]
449
+
450
+ def make_interface(self):
451
+ """
452
+ Constructs the entire Gradio Blocks interface.
453
+ """
454
+ self.make_header()
455
+
456
+ results_state = Gallery(visible=False)
457
+
458
+ image_slider = self.make_slider()
459
+
460
+ with gr.Row():
461
+ modality_selector_left, modality_selector_right = (
462
+ self.make_modality_selectors(reverse_visual_order=False)
463
+ )
464
+ user_components, btn_clear, btn_submit = self.make_advanced_settings()
465
+
466
+ self.make_examples(
467
+ image_slider,
468
+ [
469
+ results_state,
470
+ image_slider,
471
+ modality_selector_left,
472
+ modality_selector_right,
473
+ *user_components.values(),
474
+ ],
475
+ )
476
+
477
+ image_slider.upload(
478
+ fn=self.on_process_first,
479
+ inputs=[
480
+ image_slider,
481
+ modality_selector_left,
482
+ modality_selector_right,
483
+ *user_components.values(),
484
+ ],
485
+ outputs=[
486
+ results_state,
487
+ image_slider,
488
+ modality_selector_left,
489
+ modality_selector_right,
490
+ *user_components.values(),
491
+ ],
492
+ )
493
+
494
+ btn_submit.click(
495
+ fn=self.on_process_subsequent,
496
+ inputs=[
497
+ results_state,
498
+ modality_selector_left,
499
+ modality_selector_right,
500
+ *user_components.values(),
501
+ ],
502
+ outputs=[
503
+ results_state,
504
+ image_slider,
505
+ modality_selector_left,
506
+ modality_selector_right,
507
+ *user_components.values(),
508
+ ],
509
+ )
510
+
511
+ btn_clear.click(
512
+ fn=lambda: (None, None),
513
+ inputs=[],
514
+ outputs=[image_slider, results_state],
515
+ )
516
+
517
+ modality_selector_left.input(
518
+ fn=self.on_selector_change_left,
519
+ inputs=[results_state, image_slider, modality_selector_left],
520
+ outputs=image_slider,
521
+ )
522
+ modality_selector_right.input(
523
+ fn=self.on_selector_change_right,
524
+ inputs=[results_state, image_slider, modality_selector_right],
525
+ outputs=image_slider,
526
+ )
527
+
528
+ def make_header(self):
529
+ """
530
+ Create a header section with Markdown and HTML.
531
+ Default: just the project title.
532
+ """
533
+ gr.Markdown(f"# {self.title}")
534
+
535
+ def make_slider(self):
536
+ with gr.Row(elem_classes="sliderrow"):
537
+ return ImageSliderPlus(
538
+ label=self.title,
539
+ type="filepath",
540
+ elem_classes="slider",
541
+ position=self.slider_position,
542
+ )
543
+
544
+ def make_modality_selectors(self, reverse_visual_order=False):
545
+ modality_selector_left = Radio(
546
+ choices=None,
547
+ value=None,
548
+ label="Left",
549
+ key="Left",
550
+ show_label=False,
551
+ container=False,
552
+ visible=self.left_selector_visible,
553
+ render=not reverse_visual_order,
554
+ )
555
+ modality_selector_right = Radio(
556
+ choices=None,
557
+ value=None,
558
+ label="Right",
559
+ key="Right",
560
+ show_label=False,
561
+ container=False,
562
+ elem_id="selector_right",
563
+ visible=False,
564
+ render=not reverse_visual_order,
565
+ )
566
+ if reverse_visual_order:
567
+ modality_selector_right.render()
568
+ modality_selector_left.render()
569
+ return modality_selector_left, modality_selector_right
570
+
571
+ def make_examples(self, inputs, outputs):
572
+ examples = self.discover_examples()
573
+ if not isinstance(examples, list):
574
+ raise gr.Error("`discover_examples` must return a list of paths")
575
+ if any(not os.path.isfile(path) for path in examples):
576
+ raise gr.Error("Not all example paths are valid files")
577
+ examples_dirname = os.path.basename(os.path.normpath(self.examples_path))
578
+ return Examples(
579
+ examples=[
580
+ (e, e) for e in examples
581
+ ], # tuples like this seem to work better with the gallery
582
+ inputs=inputs,
583
+ outputs=outputs,
584
+ examples_per_page=self.examples_per_page,
585
+ cache_examples=self.examples_cache,
586
+ fn=self.on_process_first,
587
+ directory_name=examples_dirname,
588
+ )
589
+
590
+ def make_advanced_settings(self):
591
+ with gr.Accordion("Advanced Settings", open=False, elem_id="settings-accordion"):
592
+ user_components = self.build_user_components()
593
+ if not isinstance(user_components, dict) or any(
594
+ not isinstance(k, str) or not isinstance(v, Component)
595
+ for k, v in user_components.items()
596
+ ):
597
+ raise gr.Error(
598
+ "`build_user_components` must return a dict of Gradio components with string keys. A dict of the "
599
+ "same structure will be passed into the `process` function."
600
+ )
601
+ with gr.Row():
602
+ btn_clear, btn_submit = self.make_buttons()
603
+ self.input_keys = list(user_components.keys())
604
+ self.input_cls = list(v.__class__ for v in user_components.values())
605
+ self.input_kwargs = [
606
+ {k: v for k, v in c.constructor_args.items() if k not in ("value")}
607
+ for c in user_components.values()
608
+ ]
609
+ return user_components, btn_clear, btn_submit
610
+
611
+ def make_buttons(self):
612
+ btn_clear = gr.Button("Clear")
613
+ btn_submit = gr.Button("Apply", variant="primary")
614
+ return btn_clear, btn_submit
gradio_dualvision/gradio_patches/__init__.py ADDED
File without changes
gradio_dualvision/gradio_patches/examples.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+ from pathlib import Path
25
+ import gradio
26
+ from gradio.utils import get_cache_folder
27
+
28
+
29
+ class Examples(gradio.helpers.Examples):
30
+ def __init__(self, *args, directory_name=None, **kwargs):
31
+ super().__init__(*args, **kwargs, _initiated_directly=False)
32
+ if directory_name is not None:
33
+ self.cached_folder = get_cache_folder() / directory_name
34
+ self.cached_file = Path(self.cached_folder) / "log.csv"
35
+ self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
36
+ self.create()
gradio_dualvision/gradio_patches/gallery.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from gradio.components.gallery import (
5
+ GalleryImageType,
6
+ CaptionedGalleryImageType,
7
+ GalleryImage,
8
+ GalleryData,
9
+ )
10
+ from pathlib import Path
11
+ from urllib.parse import urlparse
12
+
13
+ import gradio
14
+ import numpy as np
15
+ import PIL.Image
16
+ from gradio_client.utils import is_http_url_like
17
+
18
+ from gradio import processing_utils, utils, wasm_utils
19
+ from gradio.data_classes import FileData
20
+
21
+
22
+ class Gallery(gradio.Gallery):
23
+ def postprocess(
24
+ self,
25
+ value: list[GalleryImageType | CaptionedGalleryImageType] | None,
26
+ ) -> GalleryData:
27
+ """
28
+ This is a patched version of the original function, wherein the format for PIL is computed based on the data type:
29
+ format = "png" if img.mode == "I;16" else "webp"
30
+ """
31
+ if value is None:
32
+ return GalleryData(root=[])
33
+ output = []
34
+
35
+ def _save(img):
36
+ url = None
37
+ caption = None
38
+ orig_name = None
39
+ if isinstance(img, (tuple, list)):
40
+ img, caption = img
41
+ if isinstance(img, np.ndarray):
42
+ file = processing_utils.save_img_array_to_cache(
43
+ img, cache_dir=self.GRADIO_CACHE, format=self.format
44
+ )
45
+ file_path = str(utils.abspath(file))
46
+ elif isinstance(img, PIL.Image.Image):
47
+ format = "png" #if img.mode == "I;16" else "webp"
48
+ file = processing_utils.save_pil_to_cache(
49
+ img, cache_dir=self.GRADIO_CACHE, format=format
50
+ )
51
+ file_path = str(utils.abspath(file))
52
+ elif isinstance(img, str):
53
+ file_path = img
54
+ if is_http_url_like(img):
55
+ url = img
56
+ orig_name = Path(urlparse(img).path).name
57
+ else:
58
+ url = None
59
+ orig_name = Path(img).name
60
+ elif isinstance(img, Path):
61
+ file_path = str(img)
62
+ orig_name = img.name
63
+ else:
64
+ raise ValueError(f"Cannot process type as image: {type(img)}")
65
+ return GalleryImage(
66
+ image=FileData(path=file_path, url=url, orig_name=orig_name),
67
+ caption=caption,
68
+ )
69
+
70
+ if wasm_utils.IS_WASM:
71
+ for img in value:
72
+ output.append(_save(img))
73
+ else:
74
+ with ThreadPoolExecutor() as executor:
75
+ for o in executor.map(_save, value):
76
+ output.append(o)
77
+ return GalleryData(root=output)
gradio_dualvision/gradio_patches/gallery.pyi ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from gradio.components.gallery import (
5
+ GalleryImageType,
6
+ CaptionedGalleryImageType,
7
+ GalleryImage,
8
+ GalleryData,
9
+ )
10
+ from pathlib import Path
11
+ from urllib.parse import urlparse
12
+
13
+ import gradio
14
+ import numpy as np
15
+ import PIL.Image
16
+ from gradio_client.utils import is_http_url_like
17
+
18
+ from gradio import processing_utils, utils, wasm_utils
19
+ from gradio.data_classes import FileData
20
+
21
+ from gradio.events import Dependency
22
+
23
+ class Gallery(gradio.Gallery):
24
+ def postprocess(
25
+ self,
26
+ value: list[GalleryImageType | CaptionedGalleryImageType] | None,
27
+ ) -> GalleryData:
28
+ """
29
+ This is a patched version of the original function, wherein the format for PIL is computed based on the data type:
30
+ format = "png" if img.mode == "I;16" else "webp"
31
+ """
32
+ if value is None:
33
+ return GalleryData(root=[])
34
+ output = []
35
+
36
+ def _save(img):
37
+ url = None
38
+ caption = None
39
+ orig_name = None
40
+ if isinstance(img, (tuple, list)):
41
+ img, caption = img
42
+ if isinstance(img, np.ndarray):
43
+ file = processing_utils.save_img_array_to_cache(
44
+ img, cache_dir=self.GRADIO_CACHE, format=self.format
45
+ )
46
+ file_path = str(utils.abspath(file))
47
+ elif isinstance(img, PIL.Image.Image):
48
+ format = "png" #if img.mode == "I;16" else "webp"
49
+ file = processing_utils.save_pil_to_cache(
50
+ img, cache_dir=self.GRADIO_CACHE, format=format
51
+ )
52
+ file_path = str(utils.abspath(file))
53
+ elif isinstance(img, str):
54
+ file_path = img
55
+ if is_http_url_like(img):
56
+ url = img
57
+ orig_name = Path(urlparse(img).path).name
58
+ else:
59
+ url = None
60
+ orig_name = Path(img).name
61
+ elif isinstance(img, Path):
62
+ file_path = str(img)
63
+ orig_name = img.name
64
+ else:
65
+ raise ValueError(f"Cannot process type as image: {type(img)}")
66
+ return GalleryImage(
67
+ image=FileData(path=file_path, url=url, orig_name=orig_name),
68
+ caption=caption,
69
+ )
70
+
71
+ if wasm_utils.IS_WASM:
72
+ for img in value:
73
+ output.append(_save(img))
74
+ else:
75
+ with ThreadPoolExecutor() as executor:
76
+ for o in executor.map(_save, value):
77
+ output.append(o)
78
+ return GalleryData(root=output)
79
+ from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
80
+ from gradio.blocks import Block
81
+ if TYPE_CHECKING:
82
+ from gradio.components import Timer
gradio_dualvision/gradio_patches/imagesliderplus.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+ import json
25
+ import os.path
26
+ import tempfile
27
+ from pathlib import Path
28
+ from typing import Union, Tuple, Optional
29
+
30
+ import numpy as np
31
+ from PIL import Image
32
+ from gradio import processing_utils
33
+ from gradio import utils
34
+ from gradio.data_classes import FileData, GradioRootModel, JsonData
35
+ from gradio_client import utils as client_utils
36
+ from gradio_imageslider import ImageSlider
37
+ from gradio_imageslider.imageslider import image_tuple, image_variants
38
+
39
+
40
+ class ImageSliderPlusData(GradioRootModel):
41
+ root: Union[
42
+ Tuple[FileData | None, FileData | None, JsonData | None],
43
+ Tuple[FileData | None, FileData | None],
44
+ None,
45
+ ]
46
+
47
+
48
+ class ImageSliderPlus(ImageSlider):
49
+ data_model = ImageSliderPlusData
50
+
51
+ def as_example(self, value):
52
+ return self.process_example_dims(value, 256, True)
53
+
54
+ def _format_image(self, im: Image):
55
+ if self.type != "filepath":
56
+ raise ValueError("ImageSliderPlus can be only created with type='filepath'")
57
+ if im is None:
58
+ return im
59
+ format = "png" #if im.mode == "I;16" else "webp"
60
+ path = processing_utils.save_pil_to_cache(
61
+ im, cache_dir=self.GRADIO_CACHE, format=format
62
+ )
63
+ self.temp_files.add(path)
64
+ return path
65
+
66
+ def _postprocess_image(self, y: image_variants):
67
+ if isinstance(y, np.ndarray):
68
+ format = "png" #if y.dtype == np.uint16 and y.squeeze().ndim == 2 else "webp"
69
+ path = processing_utils.save_img_array_to_cache(
70
+ y, cache_dir=self.GRADIO_CACHE, format=format
71
+ )
72
+ elif isinstance(y, Image.Image):
73
+ format = "png" #if y.mode == "I;16" else "webp"
74
+ path = processing_utils.save_pil_to_cache(
75
+ y, cache_dir=self.GRADIO_CACHE, format=format
76
+ )
77
+ elif isinstance(y, (str, Path)):
78
+ path = y if isinstance(y, str) else str(utils.abspath(y))
79
+ else:
80
+ raise ValueError("Cannot process this value as an Image")
81
+
82
+ return path
83
+
84
+ def postprocess(
85
+ self,
86
+ y: image_tuple,
87
+ ) -> ImageSliderPlusData:
88
+ if y is None:
89
+ return ImageSliderPlusData(root=(None, None, None))
90
+
91
+ settings = None
92
+ if type(y[0]) is str:
93
+ settings_candidate_path = y[0] + ".settings.json"
94
+ if os.path.isfile(settings_candidate_path):
95
+ with open(settings_candidate_path, "r") as fp:
96
+ settings = json.load(fp)
97
+
98
+ return ImageSliderPlusData(
99
+ root=(
100
+ FileData(path=self._postprocess_image(y[0])),
101
+ FileData(path=self._postprocess_image(y[1])),
102
+ JsonData(settings),
103
+ ),
104
+ )
105
+
106
+ def preprocess(self, x: ImageSliderPlusData) -> image_tuple:
107
+ if x is None:
108
+ return x
109
+
110
+ out_0 = self._preprocess_image(x.root[0])
111
+ out_1 = self._preprocess_image(x.root[1])
112
+
113
+ if len(x.root) > 2 and x.root[2] is not None:
114
+ with open(out_0 + ".settings.json", "w") as fp:
115
+ json.dump(x.root[2].root, fp)
116
+
117
+ return out_0, out_1
118
+
119
+ @staticmethod
120
+ def resize_and_save(image_path: str, max_dim: int, square: bool = False) -> str:
121
+ img = Image.open(image_path).convert("RGB")
122
+ if square:
123
+ width, height = img.size
124
+ min_side = min(width, height)
125
+ left = (width - min_side) // 2
126
+ top = (height - min_side) // 2
127
+ right = left + min_side
128
+ bottom = top + min_side
129
+ img = img.crop((left, top, right, bottom))
130
+ img.thumbnail((max_dim, max_dim))
131
+ temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
132
+ img.save(temp_file.name, "PNG")
133
+ return temp_file.name
134
+
135
+ def process_example_dims(
136
+ self, input_data: tuple[str | Path | None] | None, max_dim: Optional[int] = None, square: bool = False
137
+ ) -> image_tuple:
138
+ if input_data is None:
139
+ return None
140
+ input_data = (str(input_data[0]), str(input_data[1]))
141
+ if self.proxy_url or client_utils.is_http_url_like(input_data[0]):
142
+ return input_data[0]
143
+ if max_dim is not None:
144
+ input_data = (
145
+ self.resize_and_save(input_data[0], max_dim, square),
146
+ self.resize_and_save(input_data[1], max_dim, square),
147
+ )
148
+ return (
149
+ self.move_resource_to_block_cache(input_data[0]),
150
+ self.move_resource_to_block_cache(input_data[1]),
151
+ )
152
+
153
+ def process_example(
154
+ self, input_data: tuple[str | Path | None] | None
155
+ ) -> image_tuple:
156
+ return self.process_example_dims(input_data)
gradio_dualvision/gradio_patches/imagesliderplus.pyi ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+ import json
25
+ import os.path
26
+ import tempfile
27
+ from pathlib import Path
28
+ from typing import Union, Tuple, Optional
29
+
30
+ import numpy as np
31
+ from PIL import Image
32
+ from gradio import processing_utils
33
+ from gradio import utils
34
+ from gradio.data_classes import FileData, GradioRootModel, JsonData
35
+ from gradio_client import utils as client_utils
36
+ from gradio_imageslider import ImageSlider
37
+ from gradio_imageslider.imageslider import image_tuple, image_variants
38
+
39
+
40
+ class ImageSliderPlusData(GradioRootModel):
41
+ root: Union[
42
+ Tuple[FileData | None, FileData | None, JsonData | None],
43
+ Tuple[FileData | None, FileData | None],
44
+ None,
45
+ ]
46
+
47
+ from gradio.events import Dependency
48
+
49
+ class ImageSliderPlus(ImageSlider):
50
+ data_model = ImageSliderPlusData
51
+
52
+ def as_example(self, value):
53
+ return self.process_example_dims(value, 256, True)
54
+
55
+ def _format_image(self, im: Image):
56
+ if self.type != "filepath":
57
+ raise ValueError("ImageSliderPlus can be only created with type='filepath'")
58
+ if im is None:
59
+ return im
60
+ format = "png" #if im.mode == "I;16" else "webp"
61
+ path = processing_utils.save_pil_to_cache(
62
+ im, cache_dir=self.GRADIO_CACHE, format=format
63
+ )
64
+ self.temp_files.add(path)
65
+ return path
66
+
67
+ def _postprocess_image(self, y: image_variants):
68
+ if isinstance(y, np.ndarray):
69
+ format = "png" #if y.dtype == np.uint16 and y.squeeze().ndim == 2 else "webp"
70
+ path = processing_utils.save_img_array_to_cache(
71
+ y, cache_dir=self.GRADIO_CACHE, format=format
72
+ )
73
+ elif isinstance(y, Image.Image):
74
+ format = "png" #if y.mode == "I;16" else "webp"
75
+ path = processing_utils.save_pil_to_cache(
76
+ y, cache_dir=self.GRADIO_CACHE, format=format
77
+ )
78
+ elif isinstance(y, (str, Path)):
79
+ path = y if isinstance(y, str) else str(utils.abspath(y))
80
+ else:
81
+ raise ValueError("Cannot process this value as an Image")
82
+
83
+ return path
84
+
85
+ def postprocess(
86
+ self,
87
+ y: image_tuple,
88
+ ) -> ImageSliderPlusData:
89
+ if y is None:
90
+ return ImageSliderPlusData(root=(None, None, None))
91
+
92
+ settings = None
93
+ if type(y[0]) is str:
94
+ settings_candidate_path = y[0] + ".settings.json"
95
+ if os.path.isfile(settings_candidate_path):
96
+ with open(settings_candidate_path, "r") as fp:
97
+ settings = json.load(fp)
98
+
99
+ return ImageSliderPlusData(
100
+ root=(
101
+ FileData(path=self._postprocess_image(y[0])),
102
+ FileData(path=self._postprocess_image(y[1])),
103
+ JsonData(settings),
104
+ ),
105
+ )
106
+
107
+ def preprocess(self, x: ImageSliderPlusData) -> image_tuple:
108
+ if x is None:
109
+ return x
110
+
111
+ out_0 = self._preprocess_image(x.root[0])
112
+ out_1 = self._preprocess_image(x.root[1])
113
+
114
+ if len(x.root) > 2 and x.root[2] is not None:
115
+ with open(out_0 + ".settings.json", "w") as fp:
116
+ json.dump(x.root[2].root, fp)
117
+
118
+ return out_0, out_1
119
+
120
+ @staticmethod
121
+ def resize_and_save(image_path: str, max_dim: int, square: bool = False) -> str:
122
+ img = Image.open(image_path).convert("RGB")
123
+ if square:
124
+ width, height = img.size
125
+ min_side = min(width, height)
126
+ left = (width - min_side) // 2
127
+ top = (height - min_side) // 2
128
+ right = left + min_side
129
+ bottom = top + min_side
130
+ img = img.crop((left, top, right, bottom))
131
+ img.thumbnail((max_dim, max_dim))
132
+ temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
133
+ img.save(temp_file.name, "PNG")
134
+ return temp_file.name
135
+
136
+ def process_example_dims(
137
+ self, input_data: tuple[str | Path | None] | None, max_dim: Optional[int] = None, square: bool = False
138
+ ) -> image_tuple:
139
+ if input_data is None:
140
+ return None
141
+ input_data = (str(input_data[0]), str(input_data[1]))
142
+ if self.proxy_url or client_utils.is_http_url_like(input_data[0]):
143
+ return input_data[0]
144
+ if max_dim is not None:
145
+ input_data = (
146
+ self.resize_and_save(input_data[0], max_dim, square),
147
+ self.resize_and_save(input_data[1], max_dim, square),
148
+ )
149
+ return (
150
+ self.move_resource_to_block_cache(input_data[0]),
151
+ self.move_resource_to_block_cache(input_data[1]),
152
+ )
153
+
154
+ def process_example(
155
+ self, input_data: tuple[str | Path | None] | None
156
+ ) -> image_tuple:
157
+ return self.process_example_dims(input_data)
158
+ from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
159
+ from gradio.blocks import Block
160
+ if TYPE_CHECKING:
161
+ from gradio.components import Timer
gradio_dualvision/gradio_patches/radio.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+ import gradio
25
+ from gradio import components
26
+ from gradio.components.base import Component
27
+ from gradio.data_classes import (
28
+ GradioModel,
29
+ GradioRootModel,
30
+ )
31
+
32
+ from gradio.blocks import BlockContext
33
+
34
+
35
+ def patched_postprocess_update_dict(
36
+ block: Component | BlockContext, update_dict: dict, postprocess: bool = True
37
+ ):
38
+ """
39
+ This is a patched version of the original function where 'pop' is replaced with 'get' in the first line.
40
+ The key will no longer be removed but can still be accessed safely.
41
+ This fixed gradio.Radio component persisting the value selection through gradio.Examples.
42
+ """
43
+ value = update_dict.get("value", components._Keywords.NO_VALUE)
44
+
45
+ # Continue with the original logic
46
+ update_dict = {k: getattr(block, k) for k in update_dict if hasattr(block, k)}
47
+ if value is not components._Keywords.NO_VALUE:
48
+ if postprocess:
49
+ update_dict["value"] = block.postprocess(value)
50
+ if isinstance(update_dict["value"], (GradioModel, GradioRootModel)):
51
+ update_dict["value"] = update_dict["value"].model_dump()
52
+ else:
53
+ update_dict["value"] = value
54
+ update_dict["__type__"] = "update"
55
+ return update_dict
56
+
57
+
58
+ gradio.blocks.postprocess_update_dict = patched_postprocess_update_dict
59
+
60
+
61
+ class Radio(gradio.Radio):
62
+ pass
gradio_dualvision/gradio_patches/radio.pyi ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+ import gradio
25
+ from gradio import components
26
+ from gradio.components.base import Component
27
+ from gradio.data_classes import (
28
+ GradioModel,
29
+ GradioRootModel,
30
+ )
31
+
32
+ from gradio.blocks import BlockContext
33
+
34
+
35
+ def patched_postprocess_update_dict(
36
+ block: Component | BlockContext, update_dict: dict, postprocess: bool = True
37
+ ):
38
+ """
39
+ This is a patched version of the original function where 'pop' is replaced with 'get' in the first line.
40
+ The key will no longer be removed but can still be accessed safely.
41
+ This fixed gradio.Radio component persisting the value selection through gradio.Examples.
42
+ """
43
+ value = update_dict.get("value", components._Keywords.NO_VALUE)
44
+
45
+ # Continue with the original logic
46
+ update_dict = {k: getattr(block, k) for k in update_dict if hasattr(block, k)}
47
+ if value is not components._Keywords.NO_VALUE:
48
+ if postprocess:
49
+ update_dict["value"] = block.postprocess(value)
50
+ if isinstance(update_dict["value"], (GradioModel, GradioRootModel)):
51
+ update_dict["value"] = update_dict["value"].model_dump()
52
+ else:
53
+ update_dict["value"] = value
54
+ update_dict["__type__"] = "update"
55
+ return update_dict
56
+
57
+
58
+ gradio.blocks.postprocess_update_dict = patched_postprocess_update_dict
59
+
60
+ from gradio.events import Dependency
61
+
62
+ class Radio(gradio.Radio):
63
+ pass
64
+ from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
65
+ from gradio.blocks import Block
66
+ if TYPE_CHECKING:
67
+ from gradio.components import Timer
gradio_dualvision/gradio_patches/templates/component/__vite-browser-external-2447137e.js ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ const e = {};
2
+ export {
3
+ e as default
4
+ };
gradio_dualvision/gradio_patches/templates/component/index.js ADDED
The diff for this file is too large to render. See raw diff
 
gradio_dualvision/gradio_patches/templates/component/style.css ADDED
@@ -0,0 +1 @@
 
 
1
+ .wrap.svelte-1w37x6c.svelte-1w37x6c{position:relative;width:100%;height:100%;z-index:100}.icon-wrap.svelte-1w37x6c.svelte-1w37x6c{position:absolute;top:50%;transform:translateY(-50%);left:-40px;width:32px;transition:.2s;color:var(--body-text-color)}.icon-wrap.right.svelte-1w37x6c.svelte-1w37x6c{left:60px;transform:translateY(-50%) translate(-100%) rotate(180deg)}.icon-wrap.active.svelte-1w37x6c.svelte-1w37x6c,.icon-wrap.disabled.svelte-1w37x6c.svelte-1w37x6c{opacity:0}.outer.svelte-1w37x6c.svelte-1w37x6c{width:20px;height:100%;cursor:grab;position:absolute;top:0;left:0}.inner.svelte-1w37x6c.svelte-1w37x6c{box-shadow:-1px 0 6px 1px #0003;width:1px;height:100%;background:var(--color);position:absolute;left:calc((100% - 2px)/2)}.disabled.svelte-1w37x6c.svelte-1w37x6c{cursor:auto}.disabled.svelte-1w37x6c .inner.svelte-1w37x6c{box-shadow:none}.block.svelte-1t38q2d{position:relative;margin:0;box-shadow:var(--block-shadow);border-width:var(--block-border-width);border-color:var(--block-border-color);border-radius:var(--block-radius);background:var(--block-background-fill);width:100%;line-height:var(--line-sm)}.block.border_focus.svelte-1t38q2d{border-color:var(--color-accent)}.padded.svelte-1t38q2d{padding:var(--block-padding)}.hidden.svelte-1t38q2d{display:none}.hide-container.svelte-1t38q2d{margin:0;box-shadow:none;--block-border-width:0;background:transparent;padding:0;overflow:visible}div.svelte-1hnfib2{margin-bottom:var(--spacing-lg);color:var(--block-info-text-color);font-weight:var(--block-info-text-weight);font-size:var(--block-info-text-size);line-height:var(--line-sm)}span.has-info.svelte-22c38v{margin-bottom:var(--spacing-xs)}span.svelte-22c38v:not(.has-info){margin-bottom:var(--spacing-lg)}span.svelte-22c38v{display:inline-block;position:relative;z-index:var(--layer-4);border:solid var(--block-title-border-width) var(--block-title-border-color);border-radius:var(--block-title-radius);background:var(--block-title-background-fill);padding:var(--block-title-padding);color:var(--block-title-text-color);font-weight:var(--block-title-text-weight);font-size:var(--block-title-text-size);line-height:var(--line-sm)}.hide.svelte-22c38v{margin:0;height:0}label.svelte-9gxdi0{display:inline-flex;align-items:center;z-index:var(--layer-2);box-shadow:var(--block-label-shadow);border:var(--block-label-border-width) solid var(--border-color-primary);border-top:none;border-left:none;border-radius:var(--block-label-radius);background:var(--block-label-background-fill);padding:var(--block-label-padding);pointer-events:none;color:var(--block-label-text-color);font-weight:var(--block-label-text-weight);font-size:var(--block-label-text-size);line-height:var(--line-sm)}.gr-group label.svelte-9gxdi0{border-top-left-radius:0}label.float.svelte-9gxdi0{position:absolute;top:var(--block-label-margin);left:var(--block-label-margin)}label.svelte-9gxdi0:not(.float){position:static;margin-top:var(--block-label-margin);margin-left:var(--block-label-margin)}.hide.svelte-9gxdi0{height:0}span.svelte-9gxdi0{opacity:.8;margin-right:var(--size-2);width:calc(var(--block-label-text-size) - 1px);height:calc(var(--block-label-text-size) - 1px)}.hide-label.svelte-9gxdi0{box-shadow:none;border-width:0;background:transparent;overflow:visible}button.svelte-lpi64a{display:flex;justify-content:center;align-items:center;gap:1px;z-index:var(--layer-2);border-radius:var(--radius-sm);color:var(--block-label-text-color);border:1px solid transparent}button[disabled].svelte-lpi64a{opacity:.5;box-shadow:none}button[disabled].svelte-lpi64a:hover{cursor:not-allowed}.padded.svelte-lpi64a{padding:2px;background:var(--bg-color);box-shadow:var(--shadow-drop);border:1px solid var(--button-secondary-border-color)}button.svelte-lpi64a:hover,button.highlight.svelte-lpi64a{cursor:pointer;color:var(--color-accent)}.padded.svelte-lpi64a:hover{border:2px solid var(--button-secondary-border-color-hover);padding:1px;color:var(--block-label-text-color)}span.svelte-lpi64a{padding:0 1px;font-size:10px}div.svelte-lpi64a{padding:2px;display:flex;align-items:flex-end}.small.svelte-lpi64a{width:14px;height:14px}.large.svelte-lpi64a{width:22px;height:22px}.pending.svelte-lpi64a{animation:svelte-lpi64a-flash .5s infinite}@keyframes svelte-lpi64a-flash{0%{opacity:.5}50%{opacity:1}to{opacity:.5}}.transparent.svelte-lpi64a{background:transparent;border:none;box-shadow:none}.empty.svelte-3w3rth{display:flex;justify-content:center;align-items:center;margin-top:calc(0px - var(--size-6));height:var(--size-full)}.icon.svelte-3w3rth{opacity:.5;height:var(--size-5);color:var(--body-text-color)}.small.svelte-3w3rth{min-height:calc(var(--size-32) - 20px)}.large.svelte-3w3rth{min-height:calc(var(--size-64) - 20px)}.unpadded_box.svelte-3w3rth{margin-top:0}.small_parent.svelte-3w3rth{min-height:100%!important}.dropdown-arrow.svelte-145leq6{fill:currentColor}.wrap.svelte-kzcjhc{display:flex;flex-direction:column;justify-content:center;align-items:center;min-height:var(--size-60);color:var(--block-label-text-color);line-height:var(--line-md);height:100%;padding-top:var(--size-3)}.or.svelte-kzcjhc{color:var(--body-text-color-subdued);display:flex}.icon-wrap.svelte-kzcjhc{width:30px;margin-bottom:var(--spacing-lg)}@media (--screen-md){.wrap.svelte-kzcjhc{font-size:var(--text-lg)}}.hovered.svelte-kzcjhc{color:var(--color-accent)}div.svelte-ipfyu7{border-top:1px solid transparent;display:flex;max-height:100%;justify-content:center;gap:var(--spacing-sm);height:auto;align-items:flex-end;padding-bottom:var(--spacing-xl);color:var(--block-label-text-color);flex-shrink:0;width:95%}.show_border.svelte-ipfyu7{border-top:1px solid var(--block-border-color);margin-top:var(--spacing-xxl);box-shadow:var(--shadow-drop)}.source-selection.svelte-lde7lt{display:flex;align-items:center;justify-content:center;border-top:1px solid var(--border-color-primary);width:95%;bottom:0;left:0;right:0;margin-left:auto;margin-right:auto;align-self:flex-end}.icon.svelte-lde7lt{width:22px;height:22px;margin:var(--spacing-lg) var(--spacing-xs);padding:var(--spacing-xs);color:var(--neutral-400);border-radius:var(--radius-md)}.selected.svelte-lde7lt{color:var(--color-accent)}.icon.svelte-lde7lt:hover,.icon.svelte-lde7lt:focus{color:var(--color-accent)}div.svelte-1g74h68{display:flex;position:absolute;top:var(--size-2);right:var(--size-2);justify-content:flex-end;gap:var(--spacing-sm);z-index:var(--layer-5)}.wrap.svelte-1juivz4.svelte-1juivz4{overflow-y:auto;transition:opacity .5s ease-in-out;background:var(--block-background-fill);position:relative;display:flex;flex-direction:column;align-items:center;justify-content:center;min-height:var(--size-40);width:100%}.wrap.svelte-1juivz4.svelte-1juivz4:after{content:"";position:absolute;top:0;left:0;width:var(--upload-progress-width);height:100%;transition:all .5s ease-in-out;z-index:1}.uploading.svelte-1juivz4.svelte-1juivz4{font-size:var(--text-lg);font-family:var(--font);z-index:2}.file-name.svelte-1juivz4.svelte-1juivz4{margin:var(--spacing-md);font-size:var(--text-lg);color:var(--body-text-color-subdued)}.file.svelte-1juivz4.svelte-1juivz4{font-size:var(--text-md);z-index:2;display:flex;align-items:center}.file.svelte-1juivz4 progress.svelte-1juivz4{display:inline;height:var(--size-1);width:100%;transition:all .5s ease-in-out;color:var(--color-accent);border:none}.file.svelte-1juivz4 progress[value].svelte-1juivz4::-webkit-progress-value{background-color:var(--color-accent);border-radius:20px}.file.svelte-1juivz4 progress[value].svelte-1juivz4::-webkit-progress-bar{background-color:var(--border-color-accent);border-radius:20px}.progress-bar.svelte-1juivz4.svelte-1juivz4{width:14px;height:14px;border-radius:50%;background:radial-gradient(closest-side,var(--block-background-fill) 64%,transparent 53% 100%),conic-gradient(var(--color-accent) var(--upload-progress-width),var(--border-color-accent) 0);transition:all .5s ease-in-out}button.svelte-1aq8tno{cursor:pointer;width:var(--size-full)}.hidden.svelte-1aq8tno{display:none;height:0!important;position:absolute;width:0;flex-grow:0}.center.svelte-1aq8tno{display:flex;justify-content:center}.flex.svelte-1aq8tno{display:flex;justify-content:center;align-items:center}input.svelte-1aq8tno{display:none}div.svelte-1wj0ocy{display:flex;top:var(--size-2);right:var(--size-2);justify-content:flex-end;gap:var(--spacing-sm);z-index:var(--layer-1)}.not-absolute.svelte-1wj0ocy{margin:var(--size-1)}.upload-wrap.svelte-106mu0e.svelte-106mu0e{display:flex;justify-content:center;align-items:center;height:100%;width:100%}.wrap.svelte-106mu0e.svelte-106mu0e{width:100%}.half-wrap.svelte-106mu0e.svelte-106mu0e{width:50%}.image-container.svelte-106mu0e.svelte-106mu0e,img.svelte-106mu0e.svelte-106mu0e,.empty-wrap.svelte-106mu0e.svelte-106mu0e{width:var(--size-full);height:var(--size-full)}img.svelte-106mu0e.svelte-106mu0e{object-fit:cover}.fixed.svelte-106mu0e.svelte-106mu0e{--anim-block-background-fill:255, 255, 255;position:absolute;top:0;left:0;background-color:rgba(var(--anim-block-background-fill),.8);z-index:0}@media (prefers-color-scheme: dark){.fixed.svelte-106mu0e.svelte-106mu0e{--anim-block-background-fill:31, 41, 55}}.side-by-side.svelte-106mu0e img.svelte-106mu0e{width:50%;object-fit:contain}.empty-wrap.svelte-106mu0e.svelte-106mu0e{pointer-events:none}.icon-buttons.svelte-106mu0e.svelte-106mu0e{display:flex;position:absolute;right:8px;z-index:var(--layer-top);top:8px}svg.svelte-43sxxs.svelte-43sxxs{width:var(--size-20);height:var(--size-20)}svg.svelte-43sxxs path.svelte-43sxxs{fill:var(--loader-color)}div.svelte-43sxxs.svelte-43sxxs{z-index:var(--layer-2)}.margin.svelte-43sxxs.svelte-43sxxs{margin:var(--size-4)}.wrap.svelte-1txqlrd.svelte-1txqlrd{display:flex;flex-direction:column;justify-content:center;align-items:center;z-index:var(--layer-top);transition:opacity .1s ease-in-out;border-radius:var(--block-radius);background:var(--block-background-fill);padding:0 var(--size-6);max-height:var(--size-screen-h);overflow:hidden;pointer-events:none}.wrap.center.svelte-1txqlrd.svelte-1txqlrd{top:0;right:0;left:0}.wrap.default.svelte-1txqlrd.svelte-1txqlrd{top:0;right:0;bottom:0;left:0}.hide.svelte-1txqlrd.svelte-1txqlrd{opacity:0;pointer-events:none}.generating.svelte-1txqlrd.svelte-1txqlrd{animation:svelte-1txqlrd-pulse 2s cubic-bezier(.4,0,.6,1) infinite;border:2px solid var(--color-accent);background:transparent}.translucent.svelte-1txqlrd.svelte-1txqlrd{background:none}@keyframes svelte-1txqlrd-pulse{0%,to{opacity:1}50%{opacity:.5}}.loading.svelte-1txqlrd.svelte-1txqlrd{z-index:var(--layer-2);color:var(--body-text-color)}.eta-bar.svelte-1txqlrd.svelte-1txqlrd{position:absolute;top:0;right:0;bottom:0;left:0;transform-origin:left;opacity:.8;z-index:var(--layer-1);transition:10ms;background:var(--background-fill-secondary)}.progress-bar-wrap.svelte-1txqlrd.svelte-1txqlrd{border:1px solid var(--border-color-primary);background:var(--background-fill-primary);width:55.5%;height:var(--size-4)}.progress-bar.svelte-1txqlrd.svelte-1txqlrd{transform-origin:left;background-color:var(--loader-color);width:var(--size-full);height:var(--size-full)}.progress-level.svelte-1txqlrd.svelte-1txqlrd{display:flex;flex-direction:column;align-items:center;gap:1;z-index:var(--layer-2);width:var(--size-full)}.progress-level-inner.svelte-1txqlrd.svelte-1txqlrd{margin:var(--size-2) auto;color:var(--body-text-color);font-size:var(--text-sm);font-family:var(--font-mono)}.meta-text.svelte-1txqlrd.svelte-1txqlrd{position:absolute;top:0;right:0;z-index:var(--layer-2);padding:var(--size-1) var(--size-2);font-size:var(--text-sm);font-family:var(--font-mono)}.meta-text-center.svelte-1txqlrd.svelte-1txqlrd{display:flex;position:absolute;top:0;right:0;justify-content:center;align-items:center;transform:translateY(var(--size-6));z-index:var(--layer-2);padding:var(--size-1) var(--size-2);font-size:var(--text-sm);font-family:var(--font-mono);text-align:center}.error.svelte-1txqlrd.svelte-1txqlrd{box-shadow:var(--shadow-drop);border:solid 1px var(--error-border-color);border-radius:var(--radius-full);background:var(--error-background-fill);padding-right:var(--size-4);padding-left:var(--size-4);color:var(--error-text-color);font-weight:var(--weight-semibold);font-size:var(--text-lg);line-height:var(--line-lg);font-family:var(--font)}.minimal.svelte-1txqlrd .progress-text.svelte-1txqlrd{background:var(--block-background-fill)}.border.svelte-1txqlrd.svelte-1txqlrd{border:1px solid var(--border-color-primary)}.toast-body.svelte-solcu7{display:flex;position:relative;right:0;left:0;align-items:center;margin:var(--size-6) var(--size-4);margin:auto;border-radius:var(--container-radius);overflow:hidden;pointer-events:auto}.toast-body.error.svelte-solcu7{border:1px solid var(--color-red-700);background:var(--color-red-50)}.dark .toast-body.error.svelte-solcu7{border:1px solid var(--color-red-500);background-color:var(--color-grey-950)}.toast-body.warning.svelte-solcu7{border:1px solid var(--color-yellow-700);background:var(--color-yellow-50)}.dark .toast-body.warning.svelte-solcu7{border:1px solid var(--color-yellow-500);background-color:var(--color-grey-950)}.toast-body.info.svelte-solcu7{border:1px solid var(--color-grey-700);background:var(--color-grey-50)}.dark .toast-body.info.svelte-solcu7{border:1px solid var(--color-grey-500);background-color:var(--color-grey-950)}.toast-title.svelte-solcu7{display:flex;align-items:center;font-weight:var(--weight-bold);font-size:var(--text-lg);line-height:var(--line-sm);text-transform:capitalize}.toast-title.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-title.error.svelte-solcu7{color:var(--color-red-50)}.toast-title.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-title.warning.svelte-solcu7{color:var(--color-yellow-50)}.toast-title.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-title.info.svelte-solcu7{color:var(--color-grey-50)}.toast-close.svelte-solcu7{margin:0 var(--size-3);border-radius:var(--size-3);padding:0px var(--size-1-5);font-size:var(--size-5);line-height:var(--size-5)}.toast-close.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-close.error.svelte-solcu7{color:var(--color-red-500)}.toast-close.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-close.warning.svelte-solcu7{color:var(--color-yellow-500)}.toast-close.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-close.info.svelte-solcu7{color:var(--color-grey-500)}.toast-text.svelte-solcu7{font-size:var(--text-lg)}.toast-text.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-text.error.svelte-solcu7{color:var(--color-red-50)}.toast-text.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-text.warning.svelte-solcu7{color:var(--color-yellow-50)}.toast-text.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-text.info.svelte-solcu7{color:var(--color-grey-50)}.toast-details.svelte-solcu7{margin:var(--size-3) var(--size-3) var(--size-3) 0;width:100%}.toast-icon.svelte-solcu7{display:flex;position:absolute;position:relative;flex-shrink:0;justify-content:center;align-items:center;margin:var(--size-2);border-radius:var(--radius-full);padding:var(--size-1);padding-left:calc(var(--size-1) - 1px);width:35px;height:35px}.toast-icon.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-icon.error.svelte-solcu7{color:var(--color-red-500)}.toast-icon.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-icon.warning.svelte-solcu7{color:var(--color-yellow-500)}.toast-icon.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-icon.info.svelte-solcu7{color:var(--color-grey-500)}@keyframes svelte-solcu7-countdown{0%{transform:scaleX(1)}to{transform:scaleX(0)}}.timer.svelte-solcu7{position:absolute;bottom:0;left:0;transform-origin:0 0;animation:svelte-solcu7-countdown 10s linear forwards;width:100%;height:var(--size-1)}.timer.error.svelte-solcu7{background:var(--color-red-700)}.dark .timer.error.svelte-solcu7{background:var(--color-red-500)}.timer.warning.svelte-solcu7{background:var(--color-yellow-700)}.dark .timer.warning.svelte-solcu7{background:var(--color-yellow-500)}.timer.info.svelte-solcu7{background:var(--color-grey-700)}.dark .timer.info.svelte-solcu7{background:var(--color-grey-500)}.toast-wrap.svelte-gatr8h{display:flex;position:fixed;top:var(--size-4);right:var(--size-4);flex-direction:column;align-items:end;gap:var(--size-2);z-index:var(--layer-top);width:calc(100% - var(--size-8))}@media (--screen-sm){.toast-wrap.svelte-gatr8h{width:calc(var(--size-96) + var(--size-10))}}.slider-wrap.svelte-a2zf8o{-webkit-user-select:none;user-select:none;max-height:calc(100vh - 40px)}img.svelte-a2zf8o{width:var(--size-full);height:var(--size-full);object-fit:cover}.fixed.svelte-a2zf8o{position:absolute;top:0;left:0}.hidden.svelte-a2zf8o{opacity:0}.icon-buttons.svelte-a2zf8o{display:flex;position:absolute;right:8px;z-index:var(--layer-top);top:8px}.status-wrap.svelte-6wvohu{position:absolute;height:100%;width:100%;--anim-block-background-fill:255, 255, 255;z-index:1;pointer-events:none}@media (prefers-color-scheme: dark){.status-wrap.svelte-6wvohu{--anim-block-background-fill:31, 41, 55}}@keyframes svelte-6wvohu-pulse{0%{background-color:rgba(var(--anim-block-background-fill),.7)}50%{background-color:rgba(var(--anim-block-background-fill),.4)}to{background-color:rgba(var(--anim-block-background-fill),.7)}}.status-wrap.half.svelte-6wvohu .wrap{border-radius:0;animation:svelte-6wvohu-pulse 1.4s infinite ease-in-out}.status-wrap.half.svelte-6wvohu .progress-text{background:none!important}.status-wrap.half.svelte-6wvohu .eta-bar{opacity:0}.icon-buttons.svelte-6wvohu{display:flex;position:absolute;right:6px;z-index:var(--layer-1);top:6px}
gradio_dualvision/gradio_patches/templates/component/wrapper-6f348d45-19fa94bf.js ADDED
@@ -0,0 +1,2453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import S from "./__vite-browser-external-2447137e.js";
2
+ function z(s) {
3
+ return s && s.__esModule && Object.prototype.hasOwnProperty.call(s, "default") ? s.default : s;
4
+ }
5
+ function gt(s) {
6
+ if (s.__esModule)
7
+ return s;
8
+ var e = s.default;
9
+ if (typeof e == "function") {
10
+ var t = function r() {
11
+ if (this instanceof r) {
12
+ var i = [null];
13
+ i.push.apply(i, arguments);
14
+ var n = Function.bind.apply(e, i);
15
+ return new n();
16
+ }
17
+ return e.apply(this, arguments);
18
+ };
19
+ t.prototype = e.prototype;
20
+ } else
21
+ t = {};
22
+ return Object.defineProperty(t, "__esModule", { value: !0 }), Object.keys(s).forEach(function(r) {
23
+ var i = Object.getOwnPropertyDescriptor(s, r);
24
+ Object.defineProperty(t, r, i.get ? i : {
25
+ enumerable: !0,
26
+ get: function() {
27
+ return s[r];
28
+ }
29
+ });
30
+ }), t;
31
+ }
32
+ const { Duplex: yt } = S;
33
+ function Oe(s) {
34
+ s.emit("close");
35
+ }
36
+ function vt() {
37
+ !this.destroyed && this._writableState.finished && this.destroy();
38
+ }
39
+ function Qe(s) {
40
+ this.removeListener("error", Qe), this.destroy(), this.listenerCount("error") === 0 && this.emit("error", s);
41
+ }
42
+ function St(s, e) {
43
+ let t = !0;
44
+ const r = new yt({
45
+ ...e,
46
+ autoDestroy: !1,
47
+ emitClose: !1,
48
+ objectMode: !1,
49
+ writableObjectMode: !1
50
+ });
51
+ return s.on("message", function(n, o) {
52
+ const l = !o && r._readableState.objectMode ? n.toString() : n;
53
+ r.push(l) || s.pause();
54
+ }), s.once("error", function(n) {
55
+ r.destroyed || (t = !1, r.destroy(n));
56
+ }), s.once("close", function() {
57
+ r.destroyed || r.push(null);
58
+ }), r._destroy = function(i, n) {
59
+ if (s.readyState === s.CLOSED) {
60
+ n(i), process.nextTick(Oe, r);
61
+ return;
62
+ }
63
+ let o = !1;
64
+ s.once("error", function(f) {
65
+ o = !0, n(f);
66
+ }), s.once("close", function() {
67
+ o || n(i), process.nextTick(Oe, r);
68
+ }), t && s.terminate();
69
+ }, r._final = function(i) {
70
+ if (s.readyState === s.CONNECTING) {
71
+ s.once("open", function() {
72
+ r._final(i);
73
+ });
74
+ return;
75
+ }
76
+ s._socket !== null && (s._socket._writableState.finished ? (i(), r._readableState.endEmitted && r.destroy()) : (s._socket.once("finish", function() {
77
+ i();
78
+ }), s.close()));
79
+ }, r._read = function() {
80
+ s.isPaused && s.resume();
81
+ }, r._write = function(i, n, o) {
82
+ if (s.readyState === s.CONNECTING) {
83
+ s.once("open", function() {
84
+ r._write(i, n, o);
85
+ });
86
+ return;
87
+ }
88
+ s.send(i, o);
89
+ }, r.on("end", vt), r.on("error", Qe), r;
90
+ }
91
+ var Et = St;
92
+ const Vs = /* @__PURE__ */ z(Et);
93
+ var te = { exports: {} }, U = {
94
+ BINARY_TYPES: ["nodebuffer", "arraybuffer", "fragments"],
95
+ EMPTY_BUFFER: Buffer.alloc(0),
96
+ GUID: "258EAFA5-E914-47DA-95CA-C5AB0DC85B11",
97
+ kForOnEventAttribute: Symbol("kIsForOnEventAttribute"),
98
+ kListener: Symbol("kListener"),
99
+ kStatusCode: Symbol("status-code"),
100
+ kWebSocket: Symbol("websocket"),
101
+ NOOP: () => {
102
+ }
103
+ }, bt, xt;
104
+ const { EMPTY_BUFFER: kt } = U, Se = Buffer[Symbol.species];
105
+ function wt(s, e) {
106
+ if (s.length === 0)
107
+ return kt;
108
+ if (s.length === 1)
109
+ return s[0];
110
+ const t = Buffer.allocUnsafe(e);
111
+ let r = 0;
112
+ for (let i = 0; i < s.length; i++) {
113
+ const n = s[i];
114
+ t.set(n, r), r += n.length;
115
+ }
116
+ return r < e ? new Se(t.buffer, t.byteOffset, r) : t;
117
+ }
118
+ function Je(s, e, t, r, i) {
119
+ for (let n = 0; n < i; n++)
120
+ t[r + n] = s[n] ^ e[n & 3];
121
+ }
122
+ function et(s, e) {
123
+ for (let t = 0; t < s.length; t++)
124
+ s[t] ^= e[t & 3];
125
+ }
126
+ function Ot(s) {
127
+ return s.length === s.buffer.byteLength ? s.buffer : s.buffer.slice(s.byteOffset, s.byteOffset + s.length);
128
+ }
129
+ function Ee(s) {
130
+ if (Ee.readOnly = !0, Buffer.isBuffer(s))
131
+ return s;
132
+ let e;
133
+ return s instanceof ArrayBuffer ? e = new Se(s) : ArrayBuffer.isView(s) ? e = new Se(s.buffer, s.byteOffset, s.byteLength) : (e = Buffer.from(s), Ee.readOnly = !1), e;
134
+ }
135
+ te.exports = {
136
+ concat: wt,
137
+ mask: Je,
138
+ toArrayBuffer: Ot,
139
+ toBuffer: Ee,
140
+ unmask: et
141
+ };
142
+ if (!process.env.WS_NO_BUFFER_UTIL)
143
+ try {
144
+ const s = require("bufferutil");
145
+ xt = te.exports.mask = function(e, t, r, i, n) {
146
+ n < 48 ? Je(e, t, r, i, n) : s.mask(e, t, r, i, n);
147
+ }, bt = te.exports.unmask = function(e, t) {
148
+ e.length < 32 ? et(e, t) : s.unmask(e, t);
149
+ };
150
+ } catch {
151
+ }
152
+ var ne = te.exports;
153
+ const Ce = Symbol("kDone"), ue = Symbol("kRun");
154
+ let Ct = class {
155
+ /**
156
+ * Creates a new `Limiter`.
157
+ *
158
+ * @param {Number} [concurrency=Infinity] The maximum number of jobs allowed
159
+ * to run concurrently
160
+ */
161
+ constructor(e) {
162
+ this[Ce] = () => {
163
+ this.pending--, this[ue]();
164
+ }, this.concurrency = e || 1 / 0, this.jobs = [], this.pending = 0;
165
+ }
166
+ /**
167
+ * Adds a job to the queue.
168
+ *
169
+ * @param {Function} job The job to run
170
+ * @public
171
+ */
172
+ add(e) {
173
+ this.jobs.push(e), this[ue]();
174
+ }
175
+ /**
176
+ * Removes a job from the queue and runs it if possible.
177
+ *
178
+ * @private
179
+ */
180
+ [ue]() {
181
+ if (this.pending !== this.concurrency && this.jobs.length) {
182
+ const e = this.jobs.shift();
183
+ this.pending++, e(this[Ce]);
184
+ }
185
+ }
186
+ };
187
+ var Tt = Ct;
188
+ const W = S, Te = ne, Lt = Tt, { kStatusCode: tt } = U, Nt = Buffer[Symbol.species], Pt = Buffer.from([0, 0, 255, 255]), se = Symbol("permessage-deflate"), w = Symbol("total-length"), V = Symbol("callback"), C = Symbol("buffers"), J = Symbol("error");
189
+ let K, Rt = class {
190
+ /**
191
+ * Creates a PerMessageDeflate instance.
192
+ *
193
+ * @param {Object} [options] Configuration options
194
+ * @param {(Boolean|Number)} [options.clientMaxWindowBits] Advertise support
195
+ * for, or request, a custom client window size
196
+ * @param {Boolean} [options.clientNoContextTakeover=false] Advertise/
197
+ * acknowledge disabling of client context takeover
198
+ * @param {Number} [options.concurrencyLimit=10] The number of concurrent
199
+ * calls to zlib
200
+ * @param {(Boolean|Number)} [options.serverMaxWindowBits] Request/confirm the
201
+ * use of a custom server window size
202
+ * @param {Boolean} [options.serverNoContextTakeover=false] Request/accept
203
+ * disabling of server context takeover
204
+ * @param {Number} [options.threshold=1024] Size (in bytes) below which
205
+ * messages should not be compressed if context takeover is disabled
206
+ * @param {Object} [options.zlibDeflateOptions] Options to pass to zlib on
207
+ * deflate
208
+ * @param {Object} [options.zlibInflateOptions] Options to pass to zlib on
209
+ * inflate
210
+ * @param {Boolean} [isServer=false] Create the instance in either server or
211
+ * client mode
212
+ * @param {Number} [maxPayload=0] The maximum allowed message length
213
+ */
214
+ constructor(e, t, r) {
215
+ if (this._maxPayload = r | 0, this._options = e || {}, this._threshold = this._options.threshold !== void 0 ? this._options.threshold : 1024, this._isServer = !!t, this._deflate = null, this._inflate = null, this.params = null, !K) {
216
+ const i = this._options.concurrencyLimit !== void 0 ? this._options.concurrencyLimit : 10;
217
+ K = new Lt(i);
218
+ }
219
+ }
220
+ /**
221
+ * @type {String}
222
+ */
223
+ static get extensionName() {
224
+ return "permessage-deflate";
225
+ }
226
+ /**
227
+ * Create an extension negotiation offer.
228
+ *
229
+ * @return {Object} Extension parameters
230
+ * @public
231
+ */
232
+ offer() {
233
+ const e = {};
234
+ return this._options.serverNoContextTakeover && (e.server_no_context_takeover = !0), this._options.clientNoContextTakeover && (e.client_no_context_takeover = !0), this._options.serverMaxWindowBits && (e.server_max_window_bits = this._options.serverMaxWindowBits), this._options.clientMaxWindowBits ? e.client_max_window_bits = this._options.clientMaxWindowBits : this._options.clientMaxWindowBits == null && (e.client_max_window_bits = !0), e;
235
+ }
236
+ /**
237
+ * Accept an extension negotiation offer/response.
238
+ *
239
+ * @param {Array} configurations The extension negotiation offers/reponse
240
+ * @return {Object} Accepted configuration
241
+ * @public
242
+ */
243
+ accept(e) {
244
+ return e = this.normalizeParams(e), this.params = this._isServer ? this.acceptAsServer(e) : this.acceptAsClient(e), this.params;
245
+ }
246
+ /**
247
+ * Releases all resources used by the extension.
248
+ *
249
+ * @public
250
+ */
251
+ cleanup() {
252
+ if (this._inflate && (this._inflate.close(), this._inflate = null), this._deflate) {
253
+ const e = this._deflate[V];
254
+ this._deflate.close(), this._deflate = null, e && e(
255
+ new Error(
256
+ "The deflate stream was closed while data was being processed"
257
+ )
258
+ );
259
+ }
260
+ }
261
+ /**
262
+ * Accept an extension negotiation offer.
263
+ *
264
+ * @param {Array} offers The extension negotiation offers
265
+ * @return {Object} Accepted configuration
266
+ * @private
267
+ */
268
+ acceptAsServer(e) {
269
+ const t = this._options, r = e.find((i) => !(t.serverNoContextTakeover === !1 && i.server_no_context_takeover || i.server_max_window_bits && (t.serverMaxWindowBits === !1 || typeof t.serverMaxWindowBits == "number" && t.serverMaxWindowBits > i.server_max_window_bits) || typeof t.clientMaxWindowBits == "number" && !i.client_max_window_bits));
270
+ if (!r)
271
+ throw new Error("None of the extension offers can be accepted");
272
+ return t.serverNoContextTakeover && (r.server_no_context_takeover = !0), t.clientNoContextTakeover && (r.client_no_context_takeover = !0), typeof t.serverMaxWindowBits == "number" && (r.server_max_window_bits = t.serverMaxWindowBits), typeof t.clientMaxWindowBits == "number" ? r.client_max_window_bits = t.clientMaxWindowBits : (r.client_max_window_bits === !0 || t.clientMaxWindowBits === !1) && delete r.client_max_window_bits, r;
273
+ }
274
+ /**
275
+ * Accept the extension negotiation response.
276
+ *
277
+ * @param {Array} response The extension negotiation response
278
+ * @return {Object} Accepted configuration
279
+ * @private
280
+ */
281
+ acceptAsClient(e) {
282
+ const t = e[0];
283
+ if (this._options.clientNoContextTakeover === !1 && t.client_no_context_takeover)
284
+ throw new Error('Unexpected parameter "client_no_context_takeover"');
285
+ if (!t.client_max_window_bits)
286
+ typeof this._options.clientMaxWindowBits == "number" && (t.client_max_window_bits = this._options.clientMaxWindowBits);
287
+ else if (this._options.clientMaxWindowBits === !1 || typeof this._options.clientMaxWindowBits == "number" && t.client_max_window_bits > this._options.clientMaxWindowBits)
288
+ throw new Error(
289
+ 'Unexpected or invalid parameter "client_max_window_bits"'
290
+ );
291
+ return t;
292
+ }
293
+ /**
294
+ * Normalize parameters.
295
+ *
296
+ * @param {Array} configurations The extension negotiation offers/reponse
297
+ * @return {Array} The offers/response with normalized parameters
298
+ * @private
299
+ */
300
+ normalizeParams(e) {
301
+ return e.forEach((t) => {
302
+ Object.keys(t).forEach((r) => {
303
+ let i = t[r];
304
+ if (i.length > 1)
305
+ throw new Error(`Parameter "${r}" must have only a single value`);
306
+ if (i = i[0], r === "client_max_window_bits") {
307
+ if (i !== !0) {
308
+ const n = +i;
309
+ if (!Number.isInteger(n) || n < 8 || n > 15)
310
+ throw new TypeError(
311
+ `Invalid value for parameter "${r}": ${i}`
312
+ );
313
+ i = n;
314
+ } else if (!this._isServer)
315
+ throw new TypeError(
316
+ `Invalid value for parameter "${r}": ${i}`
317
+ );
318
+ } else if (r === "server_max_window_bits") {
319
+ const n = +i;
320
+ if (!Number.isInteger(n) || n < 8 || n > 15)
321
+ throw new TypeError(
322
+ `Invalid value for parameter "${r}": ${i}`
323
+ );
324
+ i = n;
325
+ } else if (r === "client_no_context_takeover" || r === "server_no_context_takeover") {
326
+ if (i !== !0)
327
+ throw new TypeError(
328
+ `Invalid value for parameter "${r}": ${i}`
329
+ );
330
+ } else
331
+ throw new Error(`Unknown parameter "${r}"`);
332
+ t[r] = i;
333
+ });
334
+ }), e;
335
+ }
336
+ /**
337
+ * Decompress data. Concurrency limited.
338
+ *
339
+ * @param {Buffer} data Compressed data
340
+ * @param {Boolean} fin Specifies whether or not this is the last fragment
341
+ * @param {Function} callback Callback
342
+ * @public
343
+ */
344
+ decompress(e, t, r) {
345
+ K.add((i) => {
346
+ this._decompress(e, t, (n, o) => {
347
+ i(), r(n, o);
348
+ });
349
+ });
350
+ }
351
+ /**
352
+ * Compress data. Concurrency limited.
353
+ *
354
+ * @param {(Buffer|String)} data Data to compress
355
+ * @param {Boolean} fin Specifies whether or not this is the last fragment
356
+ * @param {Function} callback Callback
357
+ * @public
358
+ */
359
+ compress(e, t, r) {
360
+ K.add((i) => {
361
+ this._compress(e, t, (n, o) => {
362
+ i(), r(n, o);
363
+ });
364
+ });
365
+ }
366
+ /**
367
+ * Decompress data.
368
+ *
369
+ * @param {Buffer} data Compressed data
370
+ * @param {Boolean} fin Specifies whether or not this is the last fragment
371
+ * @param {Function} callback Callback
372
+ * @private
373
+ */
374
+ _decompress(e, t, r) {
375
+ const i = this._isServer ? "client" : "server";
376
+ if (!this._inflate) {
377
+ const n = `${i}_max_window_bits`, o = typeof this.params[n] != "number" ? W.Z_DEFAULT_WINDOWBITS : this.params[n];
378
+ this._inflate = W.createInflateRaw({
379
+ ...this._options.zlibInflateOptions,
380
+ windowBits: o
381
+ }), this._inflate[se] = this, this._inflate[w] = 0, this._inflate[C] = [], this._inflate.on("error", Bt), this._inflate.on("data", st);
382
+ }
383
+ this._inflate[V] = r, this._inflate.write(e), t && this._inflate.write(Pt), this._inflate.flush(() => {
384
+ const n = this._inflate[J];
385
+ if (n) {
386
+ this._inflate.close(), this._inflate = null, r(n);
387
+ return;
388
+ }
389
+ const o = Te.concat(
390
+ this._inflate[C],
391
+ this._inflate[w]
392
+ );
393
+ this._inflate._readableState.endEmitted ? (this._inflate.close(), this._inflate = null) : (this._inflate[w] = 0, this._inflate[C] = [], t && this.params[`${i}_no_context_takeover`] && this._inflate.reset()), r(null, o);
394
+ });
395
+ }
396
+ /**
397
+ * Compress data.
398
+ *
399
+ * @param {(Buffer|String)} data Data to compress
400
+ * @param {Boolean} fin Specifies whether or not this is the last fragment
401
+ * @param {Function} callback Callback
402
+ * @private
403
+ */
404
+ _compress(e, t, r) {
405
+ const i = this._isServer ? "server" : "client";
406
+ if (!this._deflate) {
407
+ const n = `${i}_max_window_bits`, o = typeof this.params[n] != "number" ? W.Z_DEFAULT_WINDOWBITS : this.params[n];
408
+ this._deflate = W.createDeflateRaw({
409
+ ...this._options.zlibDeflateOptions,
410
+ windowBits: o
411
+ }), this._deflate[w] = 0, this._deflate[C] = [], this._deflate.on("data", Ut);
412
+ }
413
+ this._deflate[V] = r, this._deflate.write(e), this._deflate.flush(W.Z_SYNC_FLUSH, () => {
414
+ if (!this._deflate)
415
+ return;
416
+ let n = Te.concat(
417
+ this._deflate[C],
418
+ this._deflate[w]
419
+ );
420
+ t && (n = new Nt(n.buffer, n.byteOffset, n.length - 4)), this._deflate[V] = null, this._deflate[w] = 0, this._deflate[C] = [], t && this.params[`${i}_no_context_takeover`] && this._deflate.reset(), r(null, n);
421
+ });
422
+ }
423
+ };
424
+ var oe = Rt;
425
+ function Ut(s) {
426
+ this[C].push(s), this[w] += s.length;
427
+ }
428
+ function st(s) {
429
+ if (this[w] += s.length, this[se]._maxPayload < 1 || this[w] <= this[se]._maxPayload) {
430
+ this[C].push(s);
431
+ return;
432
+ }
433
+ this[J] = new RangeError("Max payload size exceeded"), this[J].code = "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH", this[J][tt] = 1009, this.removeListener("data", st), this.reset();
434
+ }
435
+ function Bt(s) {
436
+ this[se]._inflate = null, s[tt] = 1007, this[V](s);
437
+ }
438
+ var re = { exports: {} };
439
+ const $t = {}, Mt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
440
+ __proto__: null,
441
+ default: $t
442
+ }, Symbol.toStringTag, { value: "Module" })), It = /* @__PURE__ */ gt(Mt);
443
+ var Le;
444
+ const { isUtf8: Ne } = S, Dt = [
445
+ 0,
446
+ 0,
447
+ 0,
448
+ 0,
449
+ 0,
450
+ 0,
451
+ 0,
452
+ 0,
453
+ 0,
454
+ 0,
455
+ 0,
456
+ 0,
457
+ 0,
458
+ 0,
459
+ 0,
460
+ 0,
461
+ // 0 - 15
462
+ 0,
463
+ 0,
464
+ 0,
465
+ 0,
466
+ 0,
467
+ 0,
468
+ 0,
469
+ 0,
470
+ 0,
471
+ 0,
472
+ 0,
473
+ 0,
474
+ 0,
475
+ 0,
476
+ 0,
477
+ 0,
478
+ // 16 - 31
479
+ 0,
480
+ 1,
481
+ 0,
482
+ 1,
483
+ 1,
484
+ 1,
485
+ 1,
486
+ 1,
487
+ 0,
488
+ 0,
489
+ 1,
490
+ 1,
491
+ 0,
492
+ 1,
493
+ 1,
494
+ 0,
495
+ // 32 - 47
496
+ 1,
497
+ 1,
498
+ 1,
499
+ 1,
500
+ 1,
501
+ 1,
502
+ 1,
503
+ 1,
504
+ 1,
505
+ 1,
506
+ 0,
507
+ 0,
508
+ 0,
509
+ 0,
510
+ 0,
511
+ 0,
512
+ // 48 - 63
513
+ 0,
514
+ 1,
515
+ 1,
516
+ 1,
517
+ 1,
518
+ 1,
519
+ 1,
520
+ 1,
521
+ 1,
522
+ 1,
523
+ 1,
524
+ 1,
525
+ 1,
526
+ 1,
527
+ 1,
528
+ 1,
529
+ // 64 - 79
530
+ 1,
531
+ 1,
532
+ 1,
533
+ 1,
534
+ 1,
535
+ 1,
536
+ 1,
537
+ 1,
538
+ 1,
539
+ 1,
540
+ 1,
541
+ 0,
542
+ 0,
543
+ 0,
544
+ 1,
545
+ 1,
546
+ // 80 - 95
547
+ 1,
548
+ 1,
549
+ 1,
550
+ 1,
551
+ 1,
552
+ 1,
553
+ 1,
554
+ 1,
555
+ 1,
556
+ 1,
557
+ 1,
558
+ 1,
559
+ 1,
560
+ 1,
561
+ 1,
562
+ 1,
563
+ // 96 - 111
564
+ 1,
565
+ 1,
566
+ 1,
567
+ 1,
568
+ 1,
569
+ 1,
570
+ 1,
571
+ 1,
572
+ 1,
573
+ 1,
574
+ 1,
575
+ 0,
576
+ 1,
577
+ 0,
578
+ 1,
579
+ 0
580
+ // 112 - 127
581
+ ];
582
+ function Wt(s) {
583
+ return s >= 1e3 && s <= 1014 && s !== 1004 && s !== 1005 && s !== 1006 || s >= 3e3 && s <= 4999;
584
+ }
585
+ function be(s) {
586
+ const e = s.length;
587
+ let t = 0;
588
+ for (; t < e; )
589
+ if (!(s[t] & 128))
590
+ t++;
591
+ else if ((s[t] & 224) === 192) {
592
+ if (t + 1 === e || (s[t + 1] & 192) !== 128 || (s[t] & 254) === 192)
593
+ return !1;
594
+ t += 2;
595
+ } else if ((s[t] & 240) === 224) {
596
+ if (t + 2 >= e || (s[t + 1] & 192) !== 128 || (s[t + 2] & 192) !== 128 || s[t] === 224 && (s[t + 1] & 224) === 128 || // Overlong
597
+ s[t] === 237 && (s[t + 1] & 224) === 160)
598
+ return !1;
599
+ t += 3;
600
+ } else if ((s[t] & 248) === 240) {
601
+ if (t + 3 >= e || (s[t + 1] & 192) !== 128 || (s[t + 2] & 192) !== 128 || (s[t + 3] & 192) !== 128 || s[t] === 240 && (s[t + 1] & 240) === 128 || // Overlong
602
+ s[t] === 244 && s[t + 1] > 143 || s[t] > 244)
603
+ return !1;
604
+ t += 4;
605
+ } else
606
+ return !1;
607
+ return !0;
608
+ }
609
+ re.exports = {
610
+ isValidStatusCode: Wt,
611
+ isValidUTF8: be,
612
+ tokenChars: Dt
613
+ };
614
+ if (Ne)
615
+ Le = re.exports.isValidUTF8 = function(s) {
616
+ return s.length < 24 ? be(s) : Ne(s);
617
+ };
618
+ else if (!process.env.WS_NO_UTF_8_VALIDATE)
619
+ try {
620
+ const s = It;
621
+ Le = re.exports.isValidUTF8 = function(e) {
622
+ return e.length < 32 ? be(e) : s(e);
623
+ };
624
+ } catch {
625
+ }
626
+ var ae = re.exports;
627
+ const { Writable: At } = S, Pe = oe, {
628
+ BINARY_TYPES: Ft,
629
+ EMPTY_BUFFER: Re,
630
+ kStatusCode: jt,
631
+ kWebSocket: Gt
632
+ } = U, { concat: de, toArrayBuffer: Vt, unmask: Ht } = ne, { isValidStatusCode: zt, isValidUTF8: Ue } = ae, X = Buffer[Symbol.species], A = 0, Be = 1, $e = 2, Me = 3, _e = 4, Yt = 5;
633
+ let qt = class extends At {
634
+ /**
635
+ * Creates a Receiver instance.
636
+ *
637
+ * @param {Object} [options] Options object
638
+ * @param {String} [options.binaryType=nodebuffer] The type for binary data
639
+ * @param {Object} [options.extensions] An object containing the negotiated
640
+ * extensions
641
+ * @param {Boolean} [options.isServer=false] Specifies whether to operate in
642
+ * client or server mode
643
+ * @param {Number} [options.maxPayload=0] The maximum allowed message length
644
+ * @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
645
+ * not to skip UTF-8 validation for text and close messages
646
+ */
647
+ constructor(e = {}) {
648
+ super(), this._binaryType = e.binaryType || Ft[0], this._extensions = e.extensions || {}, this._isServer = !!e.isServer, this._maxPayload = e.maxPayload | 0, this._skipUTF8Validation = !!e.skipUTF8Validation, this[Gt] = void 0, this._bufferedBytes = 0, this._buffers = [], this._compressed = !1, this._payloadLength = 0, this._mask = void 0, this._fragmented = 0, this._masked = !1, this._fin = !1, this._opcode = 0, this._totalPayloadLength = 0, this._messageLength = 0, this._fragments = [], this._state = A, this._loop = !1;
649
+ }
650
+ /**
651
+ * Implements `Writable.prototype._write()`.
652
+ *
653
+ * @param {Buffer} chunk The chunk of data to write
654
+ * @param {String} encoding The character encoding of `chunk`
655
+ * @param {Function} cb Callback
656
+ * @private
657
+ */
658
+ _write(e, t, r) {
659
+ if (this._opcode === 8 && this._state == A)
660
+ return r();
661
+ this._bufferedBytes += e.length, this._buffers.push(e), this.startLoop(r);
662
+ }
663
+ /**
664
+ * Consumes `n` bytes from the buffered data.
665
+ *
666
+ * @param {Number} n The number of bytes to consume
667
+ * @return {Buffer} The consumed bytes
668
+ * @private
669
+ */
670
+ consume(e) {
671
+ if (this._bufferedBytes -= e, e === this._buffers[0].length)
672
+ return this._buffers.shift();
673
+ if (e < this._buffers[0].length) {
674
+ const r = this._buffers[0];
675
+ return this._buffers[0] = new X(
676
+ r.buffer,
677
+ r.byteOffset + e,
678
+ r.length - e
679
+ ), new X(r.buffer, r.byteOffset, e);
680
+ }
681
+ const t = Buffer.allocUnsafe(e);
682
+ do {
683
+ const r = this._buffers[0], i = t.length - e;
684
+ e >= r.length ? t.set(this._buffers.shift(), i) : (t.set(new Uint8Array(r.buffer, r.byteOffset, e), i), this._buffers[0] = new X(
685
+ r.buffer,
686
+ r.byteOffset + e,
687
+ r.length - e
688
+ )), e -= r.length;
689
+ } while (e > 0);
690
+ return t;
691
+ }
692
+ /**
693
+ * Starts the parsing loop.
694
+ *
695
+ * @param {Function} cb Callback
696
+ * @private
697
+ */
698
+ startLoop(e) {
699
+ let t;
700
+ this._loop = !0;
701
+ do
702
+ switch (this._state) {
703
+ case A:
704
+ t = this.getInfo();
705
+ break;
706
+ case Be:
707
+ t = this.getPayloadLength16();
708
+ break;
709
+ case $e:
710
+ t = this.getPayloadLength64();
711
+ break;
712
+ case Me:
713
+ this.getMask();
714
+ break;
715
+ case _e:
716
+ t = this.getData(e);
717
+ break;
718
+ default:
719
+ this._loop = !1;
720
+ return;
721
+ }
722
+ while (this._loop);
723
+ e(t);
724
+ }
725
+ /**
726
+ * Reads the first two bytes of a frame.
727
+ *
728
+ * @return {(RangeError|undefined)} A possible error
729
+ * @private
730
+ */
731
+ getInfo() {
732
+ if (this._bufferedBytes < 2) {
733
+ this._loop = !1;
734
+ return;
735
+ }
736
+ const e = this.consume(2);
737
+ if (e[0] & 48)
738
+ return this._loop = !1, g(
739
+ RangeError,
740
+ "RSV2 and RSV3 must be clear",
741
+ !0,
742
+ 1002,
743
+ "WS_ERR_UNEXPECTED_RSV_2_3"
744
+ );
745
+ const t = (e[0] & 64) === 64;
746
+ if (t && !this._extensions[Pe.extensionName])
747
+ return this._loop = !1, g(
748
+ RangeError,
749
+ "RSV1 must be clear",
750
+ !0,
751
+ 1002,
752
+ "WS_ERR_UNEXPECTED_RSV_1"
753
+ );
754
+ if (this._fin = (e[0] & 128) === 128, this._opcode = e[0] & 15, this._payloadLength = e[1] & 127, this._opcode === 0) {
755
+ if (t)
756
+ return this._loop = !1, g(
757
+ RangeError,
758
+ "RSV1 must be clear",
759
+ !0,
760
+ 1002,
761
+ "WS_ERR_UNEXPECTED_RSV_1"
762
+ );
763
+ if (!this._fragmented)
764
+ return this._loop = !1, g(
765
+ RangeError,
766
+ "invalid opcode 0",
767
+ !0,
768
+ 1002,
769
+ "WS_ERR_INVALID_OPCODE"
770
+ );
771
+ this._opcode = this._fragmented;
772
+ } else if (this._opcode === 1 || this._opcode === 2) {
773
+ if (this._fragmented)
774
+ return this._loop = !1, g(
775
+ RangeError,
776
+ `invalid opcode ${this._opcode}`,
777
+ !0,
778
+ 1002,
779
+ "WS_ERR_INVALID_OPCODE"
780
+ );
781
+ this._compressed = t;
782
+ } else if (this._opcode > 7 && this._opcode < 11) {
783
+ if (!this._fin)
784
+ return this._loop = !1, g(
785
+ RangeError,
786
+ "FIN must be set",
787
+ !0,
788
+ 1002,
789
+ "WS_ERR_EXPECTED_FIN"
790
+ );
791
+ if (t)
792
+ return this._loop = !1, g(
793
+ RangeError,
794
+ "RSV1 must be clear",
795
+ !0,
796
+ 1002,
797
+ "WS_ERR_UNEXPECTED_RSV_1"
798
+ );
799
+ if (this._payloadLength > 125 || this._opcode === 8 && this._payloadLength === 1)
800
+ return this._loop = !1, g(
801
+ RangeError,
802
+ `invalid payload length ${this._payloadLength}`,
803
+ !0,
804
+ 1002,
805
+ "WS_ERR_INVALID_CONTROL_PAYLOAD_LENGTH"
806
+ );
807
+ } else
808
+ return this._loop = !1, g(
809
+ RangeError,
810
+ `invalid opcode ${this._opcode}`,
811
+ !0,
812
+ 1002,
813
+ "WS_ERR_INVALID_OPCODE"
814
+ );
815
+ if (!this._fin && !this._fragmented && (this._fragmented = this._opcode), this._masked = (e[1] & 128) === 128, this._isServer) {
816
+ if (!this._masked)
817
+ return this._loop = !1, g(
818
+ RangeError,
819
+ "MASK must be set",
820
+ !0,
821
+ 1002,
822
+ "WS_ERR_EXPECTED_MASK"
823
+ );
824
+ } else if (this._masked)
825
+ return this._loop = !1, g(
826
+ RangeError,
827
+ "MASK must be clear",
828
+ !0,
829
+ 1002,
830
+ "WS_ERR_UNEXPECTED_MASK"
831
+ );
832
+ if (this._payloadLength === 126)
833
+ this._state = Be;
834
+ else if (this._payloadLength === 127)
835
+ this._state = $e;
836
+ else
837
+ return this.haveLength();
838
+ }
839
+ /**
840
+ * Gets extended payload length (7+16).
841
+ *
842
+ * @return {(RangeError|undefined)} A possible error
843
+ * @private
844
+ */
845
+ getPayloadLength16() {
846
+ if (this._bufferedBytes < 2) {
847
+ this._loop = !1;
848
+ return;
849
+ }
850
+ return this._payloadLength = this.consume(2).readUInt16BE(0), this.haveLength();
851
+ }
852
+ /**
853
+ * Gets extended payload length (7+64).
854
+ *
855
+ * @return {(RangeError|undefined)} A possible error
856
+ * @private
857
+ */
858
+ getPayloadLength64() {
859
+ if (this._bufferedBytes < 8) {
860
+ this._loop = !1;
861
+ return;
862
+ }
863
+ const e = this.consume(8), t = e.readUInt32BE(0);
864
+ return t > Math.pow(2, 21) - 1 ? (this._loop = !1, g(
865
+ RangeError,
866
+ "Unsupported WebSocket frame: payload length > 2^53 - 1",
867
+ !1,
868
+ 1009,
869
+ "WS_ERR_UNSUPPORTED_DATA_PAYLOAD_LENGTH"
870
+ )) : (this._payloadLength = t * Math.pow(2, 32) + e.readUInt32BE(4), this.haveLength());
871
+ }
872
+ /**
873
+ * Payload length has been read.
874
+ *
875
+ * @return {(RangeError|undefined)} A possible error
876
+ * @private
877
+ */
878
+ haveLength() {
879
+ if (this._payloadLength && this._opcode < 8 && (this._totalPayloadLength += this._payloadLength, this._totalPayloadLength > this._maxPayload && this._maxPayload > 0))
880
+ return this._loop = !1, g(
881
+ RangeError,
882
+ "Max payload size exceeded",
883
+ !1,
884
+ 1009,
885
+ "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH"
886
+ );
887
+ this._masked ? this._state = Me : this._state = _e;
888
+ }
889
+ /**
890
+ * Reads mask bytes.
891
+ *
892
+ * @private
893
+ */
894
+ getMask() {
895
+ if (this._bufferedBytes < 4) {
896
+ this._loop = !1;
897
+ return;
898
+ }
899
+ this._mask = this.consume(4), this._state = _e;
900
+ }
901
+ /**
902
+ * Reads data bytes.
903
+ *
904
+ * @param {Function} cb Callback
905
+ * @return {(Error|RangeError|undefined)} A possible error
906
+ * @private
907
+ */
908
+ getData(e) {
909
+ let t = Re;
910
+ if (this._payloadLength) {
911
+ if (this._bufferedBytes < this._payloadLength) {
912
+ this._loop = !1;
913
+ return;
914
+ }
915
+ t = this.consume(this._payloadLength), this._masked && this._mask[0] | this._mask[1] | this._mask[2] | this._mask[3] && Ht(t, this._mask);
916
+ }
917
+ if (this._opcode > 7)
918
+ return this.controlMessage(t);
919
+ if (this._compressed) {
920
+ this._state = Yt, this.decompress(t, e);
921
+ return;
922
+ }
923
+ return t.length && (this._messageLength = this._totalPayloadLength, this._fragments.push(t)), this.dataMessage();
924
+ }
925
+ /**
926
+ * Decompresses data.
927
+ *
928
+ * @param {Buffer} data Compressed data
929
+ * @param {Function} cb Callback
930
+ * @private
931
+ */
932
+ decompress(e, t) {
933
+ this._extensions[Pe.extensionName].decompress(e, this._fin, (i, n) => {
934
+ if (i)
935
+ return t(i);
936
+ if (n.length) {
937
+ if (this._messageLength += n.length, this._messageLength > this._maxPayload && this._maxPayload > 0)
938
+ return t(
939
+ g(
940
+ RangeError,
941
+ "Max payload size exceeded",
942
+ !1,
943
+ 1009,
944
+ "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH"
945
+ )
946
+ );
947
+ this._fragments.push(n);
948
+ }
949
+ const o = this.dataMessage();
950
+ if (o)
951
+ return t(o);
952
+ this.startLoop(t);
953
+ });
954
+ }
955
+ /**
956
+ * Handles a data message.
957
+ *
958
+ * @return {(Error|undefined)} A possible error
959
+ * @private
960
+ */
961
+ dataMessage() {
962
+ if (this._fin) {
963
+ const e = this._messageLength, t = this._fragments;
964
+ if (this._totalPayloadLength = 0, this._messageLength = 0, this._fragmented = 0, this._fragments = [], this._opcode === 2) {
965
+ let r;
966
+ this._binaryType === "nodebuffer" ? r = de(t, e) : this._binaryType === "arraybuffer" ? r = Vt(de(t, e)) : r = t, this.emit("message", r, !0);
967
+ } else {
968
+ const r = de(t, e);
969
+ if (!this._skipUTF8Validation && !Ue(r))
970
+ return this._loop = !1, g(
971
+ Error,
972
+ "invalid UTF-8 sequence",
973
+ !0,
974
+ 1007,
975
+ "WS_ERR_INVALID_UTF8"
976
+ );
977
+ this.emit("message", r, !1);
978
+ }
979
+ }
980
+ this._state = A;
981
+ }
982
+ /**
983
+ * Handles a control message.
984
+ *
985
+ * @param {Buffer} data Data to handle
986
+ * @return {(Error|RangeError|undefined)} A possible error
987
+ * @private
988
+ */
989
+ controlMessage(e) {
990
+ if (this._opcode === 8)
991
+ if (this._loop = !1, e.length === 0)
992
+ this.emit("conclude", 1005, Re), this.end();
993
+ else {
994
+ const t = e.readUInt16BE(0);
995
+ if (!zt(t))
996
+ return g(
997
+ RangeError,
998
+ `invalid status code ${t}`,
999
+ !0,
1000
+ 1002,
1001
+ "WS_ERR_INVALID_CLOSE_CODE"
1002
+ );
1003
+ const r = new X(
1004
+ e.buffer,
1005
+ e.byteOffset + 2,
1006
+ e.length - 2
1007
+ );
1008
+ if (!this._skipUTF8Validation && !Ue(r))
1009
+ return g(
1010
+ Error,
1011
+ "invalid UTF-8 sequence",
1012
+ !0,
1013
+ 1007,
1014
+ "WS_ERR_INVALID_UTF8"
1015
+ );
1016
+ this.emit("conclude", t, r), this.end();
1017
+ }
1018
+ else
1019
+ this._opcode === 9 ? this.emit("ping", e) : this.emit("pong", e);
1020
+ this._state = A;
1021
+ }
1022
+ };
1023
+ var rt = qt;
1024
+ function g(s, e, t, r, i) {
1025
+ const n = new s(
1026
+ t ? `Invalid WebSocket frame: ${e}` : e
1027
+ );
1028
+ return Error.captureStackTrace(n, g), n.code = i, n[jt] = r, n;
1029
+ }
1030
+ const qs = /* @__PURE__ */ z(rt), { randomFillSync: Kt } = S, Ie = oe, { EMPTY_BUFFER: Xt } = U, { isValidStatusCode: Zt } = ae, { mask: De, toBuffer: M } = ne, x = Symbol("kByteLength"), Qt = Buffer.alloc(4);
1031
+ let Jt = class P {
1032
+ /**
1033
+ * Creates a Sender instance.
1034
+ *
1035
+ * @param {(net.Socket|tls.Socket)} socket The connection socket
1036
+ * @param {Object} [extensions] An object containing the negotiated extensions
1037
+ * @param {Function} [generateMask] The function used to generate the masking
1038
+ * key
1039
+ */
1040
+ constructor(e, t, r) {
1041
+ this._extensions = t || {}, r && (this._generateMask = r, this._maskBuffer = Buffer.alloc(4)), this._socket = e, this._firstFragment = !0, this._compress = !1, this._bufferedBytes = 0, this._deflating = !1, this._queue = [];
1042
+ }
1043
+ /**
1044
+ * Frames a piece of data according to the HyBi WebSocket protocol.
1045
+ *
1046
+ * @param {(Buffer|String)} data The data to frame
1047
+ * @param {Object} options Options object
1048
+ * @param {Boolean} [options.fin=false] Specifies whether or not to set the
1049
+ * FIN bit
1050
+ * @param {Function} [options.generateMask] The function used to generate the
1051
+ * masking key
1052
+ * @param {Boolean} [options.mask=false] Specifies whether or not to mask
1053
+ * `data`
1054
+ * @param {Buffer} [options.maskBuffer] The buffer used to store the masking
1055
+ * key
1056
+ * @param {Number} options.opcode The opcode
1057
+ * @param {Boolean} [options.readOnly=false] Specifies whether `data` can be
1058
+ * modified
1059
+ * @param {Boolean} [options.rsv1=false] Specifies whether or not to set the
1060
+ * RSV1 bit
1061
+ * @return {(Buffer|String)[]} The framed data
1062
+ * @public
1063
+ */
1064
+ static frame(e, t) {
1065
+ let r, i = !1, n = 2, o = !1;
1066
+ t.mask && (r = t.maskBuffer || Qt, t.generateMask ? t.generateMask(r) : Kt(r, 0, 4), o = (r[0] | r[1] | r[2] | r[3]) === 0, n = 6);
1067
+ let l;
1068
+ typeof e == "string" ? (!t.mask || o) && t[x] !== void 0 ? l = t[x] : (e = Buffer.from(e), l = e.length) : (l = e.length, i = t.mask && t.readOnly && !o);
1069
+ let f = l;
1070
+ l >= 65536 ? (n += 8, f = 127) : l > 125 && (n += 2, f = 126);
1071
+ const a = Buffer.allocUnsafe(i ? l + n : n);
1072
+ return a[0] = t.fin ? t.opcode | 128 : t.opcode, t.rsv1 && (a[0] |= 64), a[1] = f, f === 126 ? a.writeUInt16BE(l, 2) : f === 127 && (a[2] = a[3] = 0, a.writeUIntBE(l, 4, 6)), t.mask ? (a[1] |= 128, a[n - 4] = r[0], a[n - 3] = r[1], a[n - 2] = r[2], a[n - 1] = r[3], o ? [a, e] : i ? (De(e, r, a, n, l), [a]) : (De(e, r, e, 0, l), [a, e])) : [a, e];
1073
+ }
1074
+ /**
1075
+ * Sends a close message to the other peer.
1076
+ *
1077
+ * @param {Number} [code] The status code component of the body
1078
+ * @param {(String|Buffer)} [data] The message component of the body
1079
+ * @param {Boolean} [mask=false] Specifies whether or not to mask the message
1080
+ * @param {Function} [cb] Callback
1081
+ * @public
1082
+ */
1083
+ close(e, t, r, i) {
1084
+ let n;
1085
+ if (e === void 0)
1086
+ n = Xt;
1087
+ else {
1088
+ if (typeof e != "number" || !Zt(e))
1089
+ throw new TypeError("First argument must be a valid error code number");
1090
+ if (t === void 0 || !t.length)
1091
+ n = Buffer.allocUnsafe(2), n.writeUInt16BE(e, 0);
1092
+ else {
1093
+ const l = Buffer.byteLength(t);
1094
+ if (l > 123)
1095
+ throw new RangeError("The message must not be greater than 123 bytes");
1096
+ n = Buffer.allocUnsafe(2 + l), n.writeUInt16BE(e, 0), typeof t == "string" ? n.write(t, 2) : n.set(t, 2);
1097
+ }
1098
+ }
1099
+ const o = {
1100
+ [x]: n.length,
1101
+ fin: !0,
1102
+ generateMask: this._generateMask,
1103
+ mask: r,
1104
+ maskBuffer: this._maskBuffer,
1105
+ opcode: 8,
1106
+ readOnly: !1,
1107
+ rsv1: !1
1108
+ };
1109
+ this._deflating ? this.enqueue([this.dispatch, n, !1, o, i]) : this.sendFrame(P.frame(n, o), i);
1110
+ }
1111
+ /**
1112
+ * Sends a ping message to the other peer.
1113
+ *
1114
+ * @param {*} data The message to send
1115
+ * @param {Boolean} [mask=false] Specifies whether or not to mask `data`
1116
+ * @param {Function} [cb] Callback
1117
+ * @public
1118
+ */
1119
+ ping(e, t, r) {
1120
+ let i, n;
1121
+ if (typeof e == "string" ? (i = Buffer.byteLength(e), n = !1) : (e = M(e), i = e.length, n = M.readOnly), i > 125)
1122
+ throw new RangeError("The data size must not be greater than 125 bytes");
1123
+ const o = {
1124
+ [x]: i,
1125
+ fin: !0,
1126
+ generateMask: this._generateMask,
1127
+ mask: t,
1128
+ maskBuffer: this._maskBuffer,
1129
+ opcode: 9,
1130
+ readOnly: n,
1131
+ rsv1: !1
1132
+ };
1133
+ this._deflating ? this.enqueue([this.dispatch, e, !1, o, r]) : this.sendFrame(P.frame(e, o), r);
1134
+ }
1135
+ /**
1136
+ * Sends a pong message to the other peer.
1137
+ *
1138
+ * @param {*} data The message to send
1139
+ * @param {Boolean} [mask=false] Specifies whether or not to mask `data`
1140
+ * @param {Function} [cb] Callback
1141
+ * @public
1142
+ */
1143
+ pong(e, t, r) {
1144
+ let i, n;
1145
+ if (typeof e == "string" ? (i = Buffer.byteLength(e), n = !1) : (e = M(e), i = e.length, n = M.readOnly), i > 125)
1146
+ throw new RangeError("The data size must not be greater than 125 bytes");
1147
+ const o = {
1148
+ [x]: i,
1149
+ fin: !0,
1150
+ generateMask: this._generateMask,
1151
+ mask: t,
1152
+ maskBuffer: this._maskBuffer,
1153
+ opcode: 10,
1154
+ readOnly: n,
1155
+ rsv1: !1
1156
+ };
1157
+ this._deflating ? this.enqueue([this.dispatch, e, !1, o, r]) : this.sendFrame(P.frame(e, o), r);
1158
+ }
1159
+ /**
1160
+ * Sends a data message to the other peer.
1161
+ *
1162
+ * @param {*} data The message to send
1163
+ * @param {Object} options Options object
1164
+ * @param {Boolean} [options.binary=false] Specifies whether `data` is binary
1165
+ * or text
1166
+ * @param {Boolean} [options.compress=false] Specifies whether or not to
1167
+ * compress `data`
1168
+ * @param {Boolean} [options.fin=false] Specifies whether the fragment is the
1169
+ * last one
1170
+ * @param {Boolean} [options.mask=false] Specifies whether or not to mask
1171
+ * `data`
1172
+ * @param {Function} [cb] Callback
1173
+ * @public
1174
+ */
1175
+ send(e, t, r) {
1176
+ const i = this._extensions[Ie.extensionName];
1177
+ let n = t.binary ? 2 : 1, o = t.compress, l, f;
1178
+ if (typeof e == "string" ? (l = Buffer.byteLength(e), f = !1) : (e = M(e), l = e.length, f = M.readOnly), this._firstFragment ? (this._firstFragment = !1, o && i && i.params[i._isServer ? "server_no_context_takeover" : "client_no_context_takeover"] && (o = l >= i._threshold), this._compress = o) : (o = !1, n = 0), t.fin && (this._firstFragment = !0), i) {
1179
+ const a = {
1180
+ [x]: l,
1181
+ fin: t.fin,
1182
+ generateMask: this._generateMask,
1183
+ mask: t.mask,
1184
+ maskBuffer: this._maskBuffer,
1185
+ opcode: n,
1186
+ readOnly: f,
1187
+ rsv1: o
1188
+ };
1189
+ this._deflating ? this.enqueue([this.dispatch, e, this._compress, a, r]) : this.dispatch(e, this._compress, a, r);
1190
+ } else
1191
+ this.sendFrame(
1192
+ P.frame(e, {
1193
+ [x]: l,
1194
+ fin: t.fin,
1195
+ generateMask: this._generateMask,
1196
+ mask: t.mask,
1197
+ maskBuffer: this._maskBuffer,
1198
+ opcode: n,
1199
+ readOnly: f,
1200
+ rsv1: !1
1201
+ }),
1202
+ r
1203
+ );
1204
+ }
1205
+ /**
1206
+ * Dispatches a message.
1207
+ *
1208
+ * @param {(Buffer|String)} data The message to send
1209
+ * @param {Boolean} [compress=false] Specifies whether or not to compress
1210
+ * `data`
1211
+ * @param {Object} options Options object
1212
+ * @param {Boolean} [options.fin=false] Specifies whether or not to set the
1213
+ * FIN bit
1214
+ * @param {Function} [options.generateMask] The function used to generate the
1215
+ * masking key
1216
+ * @param {Boolean} [options.mask=false] Specifies whether or not to mask
1217
+ * `data`
1218
+ * @param {Buffer} [options.maskBuffer] The buffer used to store the masking
1219
+ * key
1220
+ * @param {Number} options.opcode The opcode
1221
+ * @param {Boolean} [options.readOnly=false] Specifies whether `data` can be
1222
+ * modified
1223
+ * @param {Boolean} [options.rsv1=false] Specifies whether or not to set the
1224
+ * RSV1 bit
1225
+ * @param {Function} [cb] Callback
1226
+ * @private
1227
+ */
1228
+ dispatch(e, t, r, i) {
1229
+ if (!t) {
1230
+ this.sendFrame(P.frame(e, r), i);
1231
+ return;
1232
+ }
1233
+ const n = this._extensions[Ie.extensionName];
1234
+ this._bufferedBytes += r[x], this._deflating = !0, n.compress(e, r.fin, (o, l) => {
1235
+ if (this._socket.destroyed) {
1236
+ const f = new Error(
1237
+ "The socket was closed while data was being compressed"
1238
+ );
1239
+ typeof i == "function" && i(f);
1240
+ for (let a = 0; a < this._queue.length; a++) {
1241
+ const c = this._queue[a], h = c[c.length - 1];
1242
+ typeof h == "function" && h(f);
1243
+ }
1244
+ return;
1245
+ }
1246
+ this._bufferedBytes -= r[x], this._deflating = !1, r.readOnly = !1, this.sendFrame(P.frame(l, r), i), this.dequeue();
1247
+ });
1248
+ }
1249
+ /**
1250
+ * Executes queued send operations.
1251
+ *
1252
+ * @private
1253
+ */
1254
+ dequeue() {
1255
+ for (; !this._deflating && this._queue.length; ) {
1256
+ const e = this._queue.shift();
1257
+ this._bufferedBytes -= e[3][x], Reflect.apply(e[0], this, e.slice(1));
1258
+ }
1259
+ }
1260
+ /**
1261
+ * Enqueues a send operation.
1262
+ *
1263
+ * @param {Array} params Send operation parameters.
1264
+ * @private
1265
+ */
1266
+ enqueue(e) {
1267
+ this._bufferedBytes += e[3][x], this._queue.push(e);
1268
+ }
1269
+ /**
1270
+ * Sends a frame.
1271
+ *
1272
+ * @param {Buffer[]} list The frame to send
1273
+ * @param {Function} [cb] Callback
1274
+ * @private
1275
+ */
1276
+ sendFrame(e, t) {
1277
+ e.length === 2 ? (this._socket.cork(), this._socket.write(e[0]), this._socket.write(e[1], t), this._socket.uncork()) : this._socket.write(e[0], t);
1278
+ }
1279
+ };
1280
+ var it = Jt;
1281
+ const Ks = /* @__PURE__ */ z(it), { kForOnEventAttribute: F, kListener: pe } = U, We = Symbol("kCode"), Ae = Symbol("kData"), Fe = Symbol("kError"), je = Symbol("kMessage"), Ge = Symbol("kReason"), I = Symbol("kTarget"), Ve = Symbol("kType"), He = Symbol("kWasClean");
1282
+ class B {
1283
+ /**
1284
+ * Create a new `Event`.
1285
+ *
1286
+ * @param {String} type The name of the event
1287
+ * @throws {TypeError} If the `type` argument is not specified
1288
+ */
1289
+ constructor(e) {
1290
+ this[I] = null, this[Ve] = e;
1291
+ }
1292
+ /**
1293
+ * @type {*}
1294
+ */
1295
+ get target() {
1296
+ return this[I];
1297
+ }
1298
+ /**
1299
+ * @type {String}
1300
+ */
1301
+ get type() {
1302
+ return this[Ve];
1303
+ }
1304
+ }
1305
+ Object.defineProperty(B.prototype, "target", { enumerable: !0 });
1306
+ Object.defineProperty(B.prototype, "type", { enumerable: !0 });
1307
+ class Y extends B {
1308
+ /**
1309
+ * Create a new `CloseEvent`.
1310
+ *
1311
+ * @param {String} type The name of the event
1312
+ * @param {Object} [options] A dictionary object that allows for setting
1313
+ * attributes via object members of the same name
1314
+ * @param {Number} [options.code=0] The status code explaining why the
1315
+ * connection was closed
1316
+ * @param {String} [options.reason=''] A human-readable string explaining why
1317
+ * the connection was closed
1318
+ * @param {Boolean} [options.wasClean=false] Indicates whether or not the
1319
+ * connection was cleanly closed
1320
+ */
1321
+ constructor(e, t = {}) {
1322
+ super(e), this[We] = t.code === void 0 ? 0 : t.code, this[Ge] = t.reason === void 0 ? "" : t.reason, this[He] = t.wasClean === void 0 ? !1 : t.wasClean;
1323
+ }
1324
+ /**
1325
+ * @type {Number}
1326
+ */
1327
+ get code() {
1328
+ return this[We];
1329
+ }
1330
+ /**
1331
+ * @type {String}
1332
+ */
1333
+ get reason() {
1334
+ return this[Ge];
1335
+ }
1336
+ /**
1337
+ * @type {Boolean}
1338
+ */
1339
+ get wasClean() {
1340
+ return this[He];
1341
+ }
1342
+ }
1343
+ Object.defineProperty(Y.prototype, "code", { enumerable: !0 });
1344
+ Object.defineProperty(Y.prototype, "reason", { enumerable: !0 });
1345
+ Object.defineProperty(Y.prototype, "wasClean", { enumerable: !0 });
1346
+ class le extends B {
1347
+ /**
1348
+ * Create a new `ErrorEvent`.
1349
+ *
1350
+ * @param {String} type The name of the event
1351
+ * @param {Object} [options] A dictionary object that allows for setting
1352
+ * attributes via object members of the same name
1353
+ * @param {*} [options.error=null] The error that generated this event
1354
+ * @param {String} [options.message=''] The error message
1355
+ */
1356
+ constructor(e, t = {}) {
1357
+ super(e), this[Fe] = t.error === void 0 ? null : t.error, this[je] = t.message === void 0 ? "" : t.message;
1358
+ }
1359
+ /**
1360
+ * @type {*}
1361
+ */
1362
+ get error() {
1363
+ return this[Fe];
1364
+ }
1365
+ /**
1366
+ * @type {String}
1367
+ */
1368
+ get message() {
1369
+ return this[je];
1370
+ }
1371
+ }
1372
+ Object.defineProperty(le.prototype, "error", { enumerable: !0 });
1373
+ Object.defineProperty(le.prototype, "message", { enumerable: !0 });
1374
+ class xe extends B {
1375
+ /**
1376
+ * Create a new `MessageEvent`.
1377
+ *
1378
+ * @param {String} type The name of the event
1379
+ * @param {Object} [options] A dictionary object that allows for setting
1380
+ * attributes via object members of the same name
1381
+ * @param {*} [options.data=null] The message content
1382
+ */
1383
+ constructor(e, t = {}) {
1384
+ super(e), this[Ae] = t.data === void 0 ? null : t.data;
1385
+ }
1386
+ /**
1387
+ * @type {*}
1388
+ */
1389
+ get data() {
1390
+ return this[Ae];
1391
+ }
1392
+ }
1393
+ Object.defineProperty(xe.prototype, "data", { enumerable: !0 });
1394
+ const es = {
1395
+ /**
1396
+ * Register an event listener.
1397
+ *
1398
+ * @param {String} type A string representing the event type to listen for
1399
+ * @param {(Function|Object)} handler The listener to add
1400
+ * @param {Object} [options] An options object specifies characteristics about
1401
+ * the event listener
1402
+ * @param {Boolean} [options.once=false] A `Boolean` indicating that the
1403
+ * listener should be invoked at most once after being added. If `true`,
1404
+ * the listener would be automatically removed when invoked.
1405
+ * @public
1406
+ */
1407
+ addEventListener(s, e, t = {}) {
1408
+ for (const i of this.listeners(s))
1409
+ if (!t[F] && i[pe] === e && !i[F])
1410
+ return;
1411
+ let r;
1412
+ if (s === "message")
1413
+ r = function(n, o) {
1414
+ const l = new xe("message", {
1415
+ data: o ? n : n.toString()
1416
+ });
1417
+ l[I] = this, Z(e, this, l);
1418
+ };
1419
+ else if (s === "close")
1420
+ r = function(n, o) {
1421
+ const l = new Y("close", {
1422
+ code: n,
1423
+ reason: o.toString(),
1424
+ wasClean: this._closeFrameReceived && this._closeFrameSent
1425
+ });
1426
+ l[I] = this, Z(e, this, l);
1427
+ };
1428
+ else if (s === "error")
1429
+ r = function(n) {
1430
+ const o = new le("error", {
1431
+ error: n,
1432
+ message: n.message
1433
+ });
1434
+ o[I] = this, Z(e, this, o);
1435
+ };
1436
+ else if (s === "open")
1437
+ r = function() {
1438
+ const n = new B("open");
1439
+ n[I] = this, Z(e, this, n);
1440
+ };
1441
+ else
1442
+ return;
1443
+ r[F] = !!t[F], r[pe] = e, t.once ? this.once(s, r) : this.on(s, r);
1444
+ },
1445
+ /**
1446
+ * Remove an event listener.
1447
+ *
1448
+ * @param {String} type A string representing the event type to remove
1449
+ * @param {(Function|Object)} handler The listener to remove
1450
+ * @public
1451
+ */
1452
+ removeEventListener(s, e) {
1453
+ for (const t of this.listeners(s))
1454
+ if (t[pe] === e && !t[F]) {
1455
+ this.removeListener(s, t);
1456
+ break;
1457
+ }
1458
+ }
1459
+ };
1460
+ var ts = {
1461
+ CloseEvent: Y,
1462
+ ErrorEvent: le,
1463
+ Event: B,
1464
+ EventTarget: es,
1465
+ MessageEvent: xe
1466
+ };
1467
+ function Z(s, e, t) {
1468
+ typeof s == "object" && s.handleEvent ? s.handleEvent.call(s, t) : s.call(e, t);
1469
+ }
1470
+ const { tokenChars: j } = ae;
1471
+ function k(s, e, t) {
1472
+ s[e] === void 0 ? s[e] = [t] : s[e].push(t);
1473
+ }
1474
+ function ss(s) {
1475
+ const e = /* @__PURE__ */ Object.create(null);
1476
+ let t = /* @__PURE__ */ Object.create(null), r = !1, i = !1, n = !1, o, l, f = -1, a = -1, c = -1, h = 0;
1477
+ for (; h < s.length; h++)
1478
+ if (a = s.charCodeAt(h), o === void 0)
1479
+ if (c === -1 && j[a] === 1)
1480
+ f === -1 && (f = h);
1481
+ else if (h !== 0 && (a === 32 || a === 9))
1482
+ c === -1 && f !== -1 && (c = h);
1483
+ else if (a === 59 || a === 44) {
1484
+ if (f === -1)
1485
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1486
+ c === -1 && (c = h);
1487
+ const v = s.slice(f, c);
1488
+ a === 44 ? (k(e, v, t), t = /* @__PURE__ */ Object.create(null)) : o = v, f = c = -1;
1489
+ } else
1490
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1491
+ else if (l === void 0)
1492
+ if (c === -1 && j[a] === 1)
1493
+ f === -1 && (f = h);
1494
+ else if (a === 32 || a === 9)
1495
+ c === -1 && f !== -1 && (c = h);
1496
+ else if (a === 59 || a === 44) {
1497
+ if (f === -1)
1498
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1499
+ c === -1 && (c = h), k(t, s.slice(f, c), !0), a === 44 && (k(e, o, t), t = /* @__PURE__ */ Object.create(null), o = void 0), f = c = -1;
1500
+ } else if (a === 61 && f !== -1 && c === -1)
1501
+ l = s.slice(f, h), f = c = -1;
1502
+ else
1503
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1504
+ else if (i) {
1505
+ if (j[a] !== 1)
1506
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1507
+ f === -1 ? f = h : r || (r = !0), i = !1;
1508
+ } else if (n)
1509
+ if (j[a] === 1)
1510
+ f === -1 && (f = h);
1511
+ else if (a === 34 && f !== -1)
1512
+ n = !1, c = h;
1513
+ else if (a === 92)
1514
+ i = !0;
1515
+ else
1516
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1517
+ else if (a === 34 && s.charCodeAt(h - 1) === 61)
1518
+ n = !0;
1519
+ else if (c === -1 && j[a] === 1)
1520
+ f === -1 && (f = h);
1521
+ else if (f !== -1 && (a === 32 || a === 9))
1522
+ c === -1 && (c = h);
1523
+ else if (a === 59 || a === 44) {
1524
+ if (f === -1)
1525
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1526
+ c === -1 && (c = h);
1527
+ let v = s.slice(f, c);
1528
+ r && (v = v.replace(/\\/g, ""), r = !1), k(t, l, v), a === 44 && (k(e, o, t), t = /* @__PURE__ */ Object.create(null), o = void 0), l = void 0, f = c = -1;
1529
+ } else
1530
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1531
+ if (f === -1 || n || a === 32 || a === 9)
1532
+ throw new SyntaxError("Unexpected end of input");
1533
+ c === -1 && (c = h);
1534
+ const p = s.slice(f, c);
1535
+ return o === void 0 ? k(e, p, t) : (l === void 0 ? k(t, p, !0) : r ? k(t, l, p.replace(/\\/g, "")) : k(t, l, p), k(e, o, t)), e;
1536
+ }
1537
+ function rs(s) {
1538
+ return Object.keys(s).map((e) => {
1539
+ let t = s[e];
1540
+ return Array.isArray(t) || (t = [t]), t.map((r) => [e].concat(
1541
+ Object.keys(r).map((i) => {
1542
+ let n = r[i];
1543
+ return Array.isArray(n) || (n = [n]), n.map((o) => o === !0 ? i : `${i}=${o}`).join("; ");
1544
+ })
1545
+ ).join("; ")).join(", ");
1546
+ }).join(", ");
1547
+ }
1548
+ var nt = { format: rs, parse: ss };
1549
+ const is = S, ns = S, os = S, ot = S, as = S, { randomBytes: ls, createHash: fs } = S, { URL: me } = S, T = oe, hs = rt, cs = it, {
1550
+ BINARY_TYPES: ze,
1551
+ EMPTY_BUFFER: Q,
1552
+ GUID: us,
1553
+ kForOnEventAttribute: ge,
1554
+ kListener: ds,
1555
+ kStatusCode: _s,
1556
+ kWebSocket: y,
1557
+ NOOP: at
1558
+ } = U, {
1559
+ EventTarget: { addEventListener: ps, removeEventListener: ms }
1560
+ } = ts, { format: gs, parse: ys } = nt, { toBuffer: vs } = ne, Ss = 30 * 1e3, lt = Symbol("kAborted"), ye = [8, 13], O = ["CONNECTING", "OPEN", "CLOSING", "CLOSED"], Es = /^[!#$%&'*+\-.0-9A-Z^_`|a-z~]+$/;
1561
+ let m = class d extends is {
1562
+ /**
1563
+ * Create a new `WebSocket`.
1564
+ *
1565
+ * @param {(String|URL)} address The URL to which to connect
1566
+ * @param {(String|String[])} [protocols] The subprotocols
1567
+ * @param {Object} [options] Connection options
1568
+ */
1569
+ constructor(e, t, r) {
1570
+ super(), this._binaryType = ze[0], this._closeCode = 1006, this._closeFrameReceived = !1, this._closeFrameSent = !1, this._closeMessage = Q, this._closeTimer = null, this._extensions = {}, this._paused = !1, this._protocol = "", this._readyState = d.CONNECTING, this._receiver = null, this._sender = null, this._socket = null, e !== null ? (this._bufferedAmount = 0, this._isServer = !1, this._redirects = 0, t === void 0 ? t = [] : Array.isArray(t) || (typeof t == "object" && t !== null ? (r = t, t = []) : t = [t]), ht(this, e, t, r)) : this._isServer = !0;
1571
+ }
1572
+ /**
1573
+ * This deviates from the WHATWG interface since ws doesn't support the
1574
+ * required default "blob" type (instead we define a custom "nodebuffer"
1575
+ * type).
1576
+ *
1577
+ * @type {String}
1578
+ */
1579
+ get binaryType() {
1580
+ return this._binaryType;
1581
+ }
1582
+ set binaryType(e) {
1583
+ ze.includes(e) && (this._binaryType = e, this._receiver && (this._receiver._binaryType = e));
1584
+ }
1585
+ /**
1586
+ * @type {Number}
1587
+ */
1588
+ get bufferedAmount() {
1589
+ return this._socket ? this._socket._writableState.length + this._sender._bufferedBytes : this._bufferedAmount;
1590
+ }
1591
+ /**
1592
+ * @type {String}
1593
+ */
1594
+ get extensions() {
1595
+ return Object.keys(this._extensions).join();
1596
+ }
1597
+ /**
1598
+ * @type {Boolean}
1599
+ */
1600
+ get isPaused() {
1601
+ return this._paused;
1602
+ }
1603
+ /**
1604
+ * @type {Function}
1605
+ */
1606
+ /* istanbul ignore next */
1607
+ get onclose() {
1608
+ return null;
1609
+ }
1610
+ /**
1611
+ * @type {Function}
1612
+ */
1613
+ /* istanbul ignore next */
1614
+ get onerror() {
1615
+ return null;
1616
+ }
1617
+ /**
1618
+ * @type {Function}
1619
+ */
1620
+ /* istanbul ignore next */
1621
+ get onopen() {
1622
+ return null;
1623
+ }
1624
+ /**
1625
+ * @type {Function}
1626
+ */
1627
+ /* istanbul ignore next */
1628
+ get onmessage() {
1629
+ return null;
1630
+ }
1631
+ /**
1632
+ * @type {String}
1633
+ */
1634
+ get protocol() {
1635
+ return this._protocol;
1636
+ }
1637
+ /**
1638
+ * @type {Number}
1639
+ */
1640
+ get readyState() {
1641
+ return this._readyState;
1642
+ }
1643
+ /**
1644
+ * @type {String}
1645
+ */
1646
+ get url() {
1647
+ return this._url;
1648
+ }
1649
+ /**
1650
+ * Set up the socket and the internal resources.
1651
+ *
1652
+ * @param {(net.Socket|tls.Socket)} socket The network socket between the
1653
+ * server and client
1654
+ * @param {Buffer} head The first packet of the upgraded stream
1655
+ * @param {Object} options Options object
1656
+ * @param {Function} [options.generateMask] The function used to generate the
1657
+ * masking key
1658
+ * @param {Number} [options.maxPayload=0] The maximum allowed message size
1659
+ * @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
1660
+ * not to skip UTF-8 validation for text and close messages
1661
+ * @private
1662
+ */
1663
+ setSocket(e, t, r) {
1664
+ const i = new hs({
1665
+ binaryType: this.binaryType,
1666
+ extensions: this._extensions,
1667
+ isServer: this._isServer,
1668
+ maxPayload: r.maxPayload,
1669
+ skipUTF8Validation: r.skipUTF8Validation
1670
+ });
1671
+ this._sender = new cs(e, this._extensions, r.generateMask), this._receiver = i, this._socket = e, i[y] = this, e[y] = this, i.on("conclude", ks), i.on("drain", ws), i.on("error", Os), i.on("message", Cs), i.on("ping", Ts), i.on("pong", Ls), e.setTimeout(0), e.setNoDelay(), t.length > 0 && e.unshift(t), e.on("close", ut), e.on("data", fe), e.on("end", dt), e.on("error", _t), this._readyState = d.OPEN, this.emit("open");
1672
+ }
1673
+ /**
1674
+ * Emit the `'close'` event.
1675
+ *
1676
+ * @private
1677
+ */
1678
+ emitClose() {
1679
+ if (!this._socket) {
1680
+ this._readyState = d.CLOSED, this.emit("close", this._closeCode, this._closeMessage);
1681
+ return;
1682
+ }
1683
+ this._extensions[T.extensionName] && this._extensions[T.extensionName].cleanup(), this._receiver.removeAllListeners(), this._readyState = d.CLOSED, this.emit("close", this._closeCode, this._closeMessage);
1684
+ }
1685
+ /**
1686
+ * Start a closing handshake.
1687
+ *
1688
+ * +----------+ +-----------+ +----------+
1689
+ * - - -|ws.close()|-->|close frame|-->|ws.close()|- - -
1690
+ * | +----------+ +-----------+ +----------+ |
1691
+ * +----------+ +-----------+ |
1692
+ * CLOSING |ws.close()|<--|close frame|<--+-----+ CLOSING
1693
+ * +----------+ +-----------+ |
1694
+ * | | | +---+ |
1695
+ * +------------------------+-->|fin| - - - -
1696
+ * | +---+ | +---+
1697
+ * - - - - -|fin|<---------------------+
1698
+ * +---+
1699
+ *
1700
+ * @param {Number} [code] Status code explaining why the connection is closing
1701
+ * @param {(String|Buffer)} [data] The reason why the connection is
1702
+ * closing
1703
+ * @public
1704
+ */
1705
+ close(e, t) {
1706
+ if (this.readyState !== d.CLOSED) {
1707
+ if (this.readyState === d.CONNECTING) {
1708
+ b(this, this._req, "WebSocket was closed before the connection was established");
1709
+ return;
1710
+ }
1711
+ if (this.readyState === d.CLOSING) {
1712
+ this._closeFrameSent && (this._closeFrameReceived || this._receiver._writableState.errorEmitted) && this._socket.end();
1713
+ return;
1714
+ }
1715
+ this._readyState = d.CLOSING, this._sender.close(e, t, !this._isServer, (r) => {
1716
+ r || (this._closeFrameSent = !0, (this._closeFrameReceived || this._receiver._writableState.errorEmitted) && this._socket.end());
1717
+ }), this._closeTimer = setTimeout(
1718
+ this._socket.destroy.bind(this._socket),
1719
+ Ss
1720
+ );
1721
+ }
1722
+ }
1723
+ /**
1724
+ * Pause the socket.
1725
+ *
1726
+ * @public
1727
+ */
1728
+ pause() {
1729
+ this.readyState === d.CONNECTING || this.readyState === d.CLOSED || (this._paused = !0, this._socket.pause());
1730
+ }
1731
+ /**
1732
+ * Send a ping.
1733
+ *
1734
+ * @param {*} [data] The data to send
1735
+ * @param {Boolean} [mask] Indicates whether or not to mask `data`
1736
+ * @param {Function} [cb] Callback which is executed when the ping is sent
1737
+ * @public
1738
+ */
1739
+ ping(e, t, r) {
1740
+ if (this.readyState === d.CONNECTING)
1741
+ throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
1742
+ if (typeof e == "function" ? (r = e, e = t = void 0) : typeof t == "function" && (r = t, t = void 0), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
1743
+ ve(this, e, r);
1744
+ return;
1745
+ }
1746
+ t === void 0 && (t = !this._isServer), this._sender.ping(e || Q, t, r);
1747
+ }
1748
+ /**
1749
+ * Send a pong.
1750
+ *
1751
+ * @param {*} [data] The data to send
1752
+ * @param {Boolean} [mask] Indicates whether or not to mask `data`
1753
+ * @param {Function} [cb] Callback which is executed when the pong is sent
1754
+ * @public
1755
+ */
1756
+ pong(e, t, r) {
1757
+ if (this.readyState === d.CONNECTING)
1758
+ throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
1759
+ if (typeof e == "function" ? (r = e, e = t = void 0) : typeof t == "function" && (r = t, t = void 0), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
1760
+ ve(this, e, r);
1761
+ return;
1762
+ }
1763
+ t === void 0 && (t = !this._isServer), this._sender.pong(e || Q, t, r);
1764
+ }
1765
+ /**
1766
+ * Resume the socket.
1767
+ *
1768
+ * @public
1769
+ */
1770
+ resume() {
1771
+ this.readyState === d.CONNECTING || this.readyState === d.CLOSED || (this._paused = !1, this._receiver._writableState.needDrain || this._socket.resume());
1772
+ }
1773
+ /**
1774
+ * Send a data message.
1775
+ *
1776
+ * @param {*} data The message to send
1777
+ * @param {Object} [options] Options object
1778
+ * @param {Boolean} [options.binary] Specifies whether `data` is binary or
1779
+ * text
1780
+ * @param {Boolean} [options.compress] Specifies whether or not to compress
1781
+ * `data`
1782
+ * @param {Boolean} [options.fin=true] Specifies whether the fragment is the
1783
+ * last one
1784
+ * @param {Boolean} [options.mask] Specifies whether or not to mask `data`
1785
+ * @param {Function} [cb] Callback which is executed when data is written out
1786
+ * @public
1787
+ */
1788
+ send(e, t, r) {
1789
+ if (this.readyState === d.CONNECTING)
1790
+ throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
1791
+ if (typeof t == "function" && (r = t, t = {}), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
1792
+ ve(this, e, r);
1793
+ return;
1794
+ }
1795
+ const i = {
1796
+ binary: typeof e != "string",
1797
+ mask: !this._isServer,
1798
+ compress: !0,
1799
+ fin: !0,
1800
+ ...t
1801
+ };
1802
+ this._extensions[T.extensionName] || (i.compress = !1), this._sender.send(e || Q, i, r);
1803
+ }
1804
+ /**
1805
+ * Forcibly close the connection.
1806
+ *
1807
+ * @public
1808
+ */
1809
+ terminate() {
1810
+ if (this.readyState !== d.CLOSED) {
1811
+ if (this.readyState === d.CONNECTING) {
1812
+ b(this, this._req, "WebSocket was closed before the connection was established");
1813
+ return;
1814
+ }
1815
+ this._socket && (this._readyState = d.CLOSING, this._socket.destroy());
1816
+ }
1817
+ }
1818
+ };
1819
+ Object.defineProperty(m, "CONNECTING", {
1820
+ enumerable: !0,
1821
+ value: O.indexOf("CONNECTING")
1822
+ });
1823
+ Object.defineProperty(m.prototype, "CONNECTING", {
1824
+ enumerable: !0,
1825
+ value: O.indexOf("CONNECTING")
1826
+ });
1827
+ Object.defineProperty(m, "OPEN", {
1828
+ enumerable: !0,
1829
+ value: O.indexOf("OPEN")
1830
+ });
1831
+ Object.defineProperty(m.prototype, "OPEN", {
1832
+ enumerable: !0,
1833
+ value: O.indexOf("OPEN")
1834
+ });
1835
+ Object.defineProperty(m, "CLOSING", {
1836
+ enumerable: !0,
1837
+ value: O.indexOf("CLOSING")
1838
+ });
1839
+ Object.defineProperty(m.prototype, "CLOSING", {
1840
+ enumerable: !0,
1841
+ value: O.indexOf("CLOSING")
1842
+ });
1843
+ Object.defineProperty(m, "CLOSED", {
1844
+ enumerable: !0,
1845
+ value: O.indexOf("CLOSED")
1846
+ });
1847
+ Object.defineProperty(m.prototype, "CLOSED", {
1848
+ enumerable: !0,
1849
+ value: O.indexOf("CLOSED")
1850
+ });
1851
+ [
1852
+ "binaryType",
1853
+ "bufferedAmount",
1854
+ "extensions",
1855
+ "isPaused",
1856
+ "protocol",
1857
+ "readyState",
1858
+ "url"
1859
+ ].forEach((s) => {
1860
+ Object.defineProperty(m.prototype, s, { enumerable: !0 });
1861
+ });
1862
+ ["open", "error", "close", "message"].forEach((s) => {
1863
+ Object.defineProperty(m.prototype, `on${s}`, {
1864
+ enumerable: !0,
1865
+ get() {
1866
+ for (const e of this.listeners(s))
1867
+ if (e[ge])
1868
+ return e[ds];
1869
+ return null;
1870
+ },
1871
+ set(e) {
1872
+ for (const t of this.listeners(s))
1873
+ if (t[ge]) {
1874
+ this.removeListener(s, t);
1875
+ break;
1876
+ }
1877
+ typeof e == "function" && this.addEventListener(s, e, {
1878
+ [ge]: !0
1879
+ });
1880
+ }
1881
+ });
1882
+ });
1883
+ m.prototype.addEventListener = ps;
1884
+ m.prototype.removeEventListener = ms;
1885
+ var ft = m;
1886
+ function ht(s, e, t, r) {
1887
+ const i = {
1888
+ protocolVersion: ye[1],
1889
+ maxPayload: 104857600,
1890
+ skipUTF8Validation: !1,
1891
+ perMessageDeflate: !0,
1892
+ followRedirects: !1,
1893
+ maxRedirects: 10,
1894
+ ...r,
1895
+ createConnection: void 0,
1896
+ socketPath: void 0,
1897
+ hostname: void 0,
1898
+ protocol: void 0,
1899
+ timeout: void 0,
1900
+ method: "GET",
1901
+ host: void 0,
1902
+ path: void 0,
1903
+ port: void 0
1904
+ };
1905
+ if (!ye.includes(i.protocolVersion))
1906
+ throw new RangeError(
1907
+ `Unsupported protocol version: ${i.protocolVersion} (supported versions: ${ye.join(", ")})`
1908
+ );
1909
+ let n;
1910
+ if (e instanceof me)
1911
+ n = e, s._url = e.href;
1912
+ else {
1913
+ try {
1914
+ n = new me(e);
1915
+ } catch {
1916
+ throw new SyntaxError(`Invalid URL: ${e}`);
1917
+ }
1918
+ s._url = e;
1919
+ }
1920
+ const o = n.protocol === "wss:", l = n.protocol === "ws+unix:";
1921
+ let f;
1922
+ if (n.protocol !== "ws:" && !o && !l ? f = `The URL's protocol must be one of "ws:", "wss:", or "ws+unix:"` : l && !n.pathname ? f = "The URL's pathname is empty" : n.hash && (f = "The URL contains a fragment identifier"), f) {
1923
+ const u = new SyntaxError(f);
1924
+ if (s._redirects === 0)
1925
+ throw u;
1926
+ ee(s, u);
1927
+ return;
1928
+ }
1929
+ const a = o ? 443 : 80, c = ls(16).toString("base64"), h = o ? ns.request : os.request, p = /* @__PURE__ */ new Set();
1930
+ let v;
1931
+ if (i.createConnection = o ? xs : bs, i.defaultPort = i.defaultPort || a, i.port = n.port || a, i.host = n.hostname.startsWith("[") ? n.hostname.slice(1, -1) : n.hostname, i.headers = {
1932
+ ...i.headers,
1933
+ "Sec-WebSocket-Version": i.protocolVersion,
1934
+ "Sec-WebSocket-Key": c,
1935
+ Connection: "Upgrade",
1936
+ Upgrade: "websocket"
1937
+ }, i.path = n.pathname + n.search, i.timeout = i.handshakeTimeout, i.perMessageDeflate && (v = new T(
1938
+ i.perMessageDeflate !== !0 ? i.perMessageDeflate : {},
1939
+ !1,
1940
+ i.maxPayload
1941
+ ), i.headers["Sec-WebSocket-Extensions"] = gs({
1942
+ [T.extensionName]: v.offer()
1943
+ })), t.length) {
1944
+ for (const u of t) {
1945
+ if (typeof u != "string" || !Es.test(u) || p.has(u))
1946
+ throw new SyntaxError(
1947
+ "An invalid or duplicated subprotocol was specified"
1948
+ );
1949
+ p.add(u);
1950
+ }
1951
+ i.headers["Sec-WebSocket-Protocol"] = t.join(",");
1952
+ }
1953
+ if (i.origin && (i.protocolVersion < 13 ? i.headers["Sec-WebSocket-Origin"] = i.origin : i.headers.Origin = i.origin), (n.username || n.password) && (i.auth = `${n.username}:${n.password}`), l) {
1954
+ const u = i.path.split(":");
1955
+ i.socketPath = u[0], i.path = u[1];
1956
+ }
1957
+ let _;
1958
+ if (i.followRedirects) {
1959
+ if (s._redirects === 0) {
1960
+ s._originalIpc = l, s._originalSecure = o, s._originalHostOrSocketPath = l ? i.socketPath : n.host;
1961
+ const u = r && r.headers;
1962
+ if (r = { ...r, headers: {} }, u)
1963
+ for (const [E, $] of Object.entries(u))
1964
+ r.headers[E.toLowerCase()] = $;
1965
+ } else if (s.listenerCount("redirect") === 0) {
1966
+ const u = l ? s._originalIpc ? i.socketPath === s._originalHostOrSocketPath : !1 : s._originalIpc ? !1 : n.host === s._originalHostOrSocketPath;
1967
+ (!u || s._originalSecure && !o) && (delete i.headers.authorization, delete i.headers.cookie, u || delete i.headers.host, i.auth = void 0);
1968
+ }
1969
+ i.auth && !r.headers.authorization && (r.headers.authorization = "Basic " + Buffer.from(i.auth).toString("base64")), _ = s._req = h(i), s._redirects && s.emit("redirect", s.url, _);
1970
+ } else
1971
+ _ = s._req = h(i);
1972
+ i.timeout && _.on("timeout", () => {
1973
+ b(s, _, "Opening handshake has timed out");
1974
+ }), _.on("error", (u) => {
1975
+ _ === null || _[lt] || (_ = s._req = null, ee(s, u));
1976
+ }), _.on("response", (u) => {
1977
+ const E = u.headers.location, $ = u.statusCode;
1978
+ if (E && i.followRedirects && $ >= 300 && $ < 400) {
1979
+ if (++s._redirects > i.maxRedirects) {
1980
+ b(s, _, "Maximum redirects exceeded");
1981
+ return;
1982
+ }
1983
+ _.abort();
1984
+ let q;
1985
+ try {
1986
+ q = new me(E, e);
1987
+ } catch {
1988
+ const L = new SyntaxError(`Invalid URL: ${E}`);
1989
+ ee(s, L);
1990
+ return;
1991
+ }
1992
+ ht(s, q, t, r);
1993
+ } else
1994
+ s.emit("unexpected-response", _, u) || b(
1995
+ s,
1996
+ _,
1997
+ `Unexpected server response: ${u.statusCode}`
1998
+ );
1999
+ }), _.on("upgrade", (u, E, $) => {
2000
+ if (s.emit("upgrade", u), s.readyState !== m.CONNECTING)
2001
+ return;
2002
+ if (_ = s._req = null, u.headers.upgrade.toLowerCase() !== "websocket") {
2003
+ b(s, E, "Invalid Upgrade header");
2004
+ return;
2005
+ }
2006
+ const q = fs("sha1").update(c + us).digest("base64");
2007
+ if (u.headers["sec-websocket-accept"] !== q) {
2008
+ b(s, E, "Invalid Sec-WebSocket-Accept header");
2009
+ return;
2010
+ }
2011
+ const D = u.headers["sec-websocket-protocol"];
2012
+ let L;
2013
+ if (D !== void 0 ? p.size ? p.has(D) || (L = "Server sent an invalid subprotocol") : L = "Server sent a subprotocol but none was requested" : p.size && (L = "Server sent no subprotocol"), L) {
2014
+ b(s, E, L);
2015
+ return;
2016
+ }
2017
+ D && (s._protocol = D);
2018
+ const ke = u.headers["sec-websocket-extensions"];
2019
+ if (ke !== void 0) {
2020
+ if (!v) {
2021
+ b(s, E, "Server sent a Sec-WebSocket-Extensions header but no extension was requested");
2022
+ return;
2023
+ }
2024
+ let he;
2025
+ try {
2026
+ he = ys(ke);
2027
+ } catch {
2028
+ b(s, E, "Invalid Sec-WebSocket-Extensions header");
2029
+ return;
2030
+ }
2031
+ const we = Object.keys(he);
2032
+ if (we.length !== 1 || we[0] !== T.extensionName) {
2033
+ b(s, E, "Server indicated an extension that was not requested");
2034
+ return;
2035
+ }
2036
+ try {
2037
+ v.accept(he[T.extensionName]);
2038
+ } catch {
2039
+ b(s, E, "Invalid Sec-WebSocket-Extensions header");
2040
+ return;
2041
+ }
2042
+ s._extensions[T.extensionName] = v;
2043
+ }
2044
+ s.setSocket(E, $, {
2045
+ generateMask: i.generateMask,
2046
+ maxPayload: i.maxPayload,
2047
+ skipUTF8Validation: i.skipUTF8Validation
2048
+ });
2049
+ }), i.finishRequest ? i.finishRequest(_, s) : _.end();
2050
+ }
2051
+ function ee(s, e) {
2052
+ s._readyState = m.CLOSING, s.emit("error", e), s.emitClose();
2053
+ }
2054
+ function bs(s) {
2055
+ return s.path = s.socketPath, ot.connect(s);
2056
+ }
2057
+ function xs(s) {
2058
+ return s.path = void 0, !s.servername && s.servername !== "" && (s.servername = ot.isIP(s.host) ? "" : s.host), as.connect(s);
2059
+ }
2060
+ function b(s, e, t) {
2061
+ s._readyState = m.CLOSING;
2062
+ const r = new Error(t);
2063
+ Error.captureStackTrace(r, b), e.setHeader ? (e[lt] = !0, e.abort(), e.socket && !e.socket.destroyed && e.socket.destroy(), process.nextTick(ee, s, r)) : (e.destroy(r), e.once("error", s.emit.bind(s, "error")), e.once("close", s.emitClose.bind(s)));
2064
+ }
2065
+ function ve(s, e, t) {
2066
+ if (e) {
2067
+ const r = vs(e).length;
2068
+ s._socket ? s._sender._bufferedBytes += r : s._bufferedAmount += r;
2069
+ }
2070
+ if (t) {
2071
+ const r = new Error(
2072
+ `WebSocket is not open: readyState ${s.readyState} (${O[s.readyState]})`
2073
+ );
2074
+ process.nextTick(t, r);
2075
+ }
2076
+ }
2077
+ function ks(s, e) {
2078
+ const t = this[y];
2079
+ t._closeFrameReceived = !0, t._closeMessage = e, t._closeCode = s, t._socket[y] !== void 0 && (t._socket.removeListener("data", fe), process.nextTick(ct, t._socket), s === 1005 ? t.close() : t.close(s, e));
2080
+ }
2081
+ function ws() {
2082
+ const s = this[y];
2083
+ s.isPaused || s._socket.resume();
2084
+ }
2085
+ function Os(s) {
2086
+ const e = this[y];
2087
+ e._socket[y] !== void 0 && (e._socket.removeListener("data", fe), process.nextTick(ct, e._socket), e.close(s[_s])), e.emit("error", s);
2088
+ }
2089
+ function Ye() {
2090
+ this[y].emitClose();
2091
+ }
2092
+ function Cs(s, e) {
2093
+ this[y].emit("message", s, e);
2094
+ }
2095
+ function Ts(s) {
2096
+ const e = this[y];
2097
+ e.pong(s, !e._isServer, at), e.emit("ping", s);
2098
+ }
2099
+ function Ls(s) {
2100
+ this[y].emit("pong", s);
2101
+ }
2102
+ function ct(s) {
2103
+ s.resume();
2104
+ }
2105
+ function ut() {
2106
+ const s = this[y];
2107
+ this.removeListener("close", ut), this.removeListener("data", fe), this.removeListener("end", dt), s._readyState = m.CLOSING;
2108
+ let e;
2109
+ !this._readableState.endEmitted && !s._closeFrameReceived && !s._receiver._writableState.errorEmitted && (e = s._socket.read()) !== null && s._receiver.write(e), s._receiver.end(), this[y] = void 0, clearTimeout(s._closeTimer), s._receiver._writableState.finished || s._receiver._writableState.errorEmitted ? s.emitClose() : (s._receiver.on("error", Ye), s._receiver.on("finish", Ye));
2110
+ }
2111
+ function fe(s) {
2112
+ this[y]._receiver.write(s) || this.pause();
2113
+ }
2114
+ function dt() {
2115
+ const s = this[y];
2116
+ s._readyState = m.CLOSING, s._receiver.end(), this.end();
2117
+ }
2118
+ function _t() {
2119
+ const s = this[y];
2120
+ this.removeListener("error", _t), this.on("error", at), s && (s._readyState = m.CLOSING, this.destroy());
2121
+ }
2122
+ const Xs = /* @__PURE__ */ z(ft), { tokenChars: Ns } = ae;
2123
+ function Ps(s) {
2124
+ const e = /* @__PURE__ */ new Set();
2125
+ let t = -1, r = -1, i = 0;
2126
+ for (i; i < s.length; i++) {
2127
+ const o = s.charCodeAt(i);
2128
+ if (r === -1 && Ns[o] === 1)
2129
+ t === -1 && (t = i);
2130
+ else if (i !== 0 && (o === 32 || o === 9))
2131
+ r === -1 && t !== -1 && (r = i);
2132
+ else if (o === 44) {
2133
+ if (t === -1)
2134
+ throw new SyntaxError(`Unexpected character at index ${i}`);
2135
+ r === -1 && (r = i);
2136
+ const l = s.slice(t, r);
2137
+ if (e.has(l))
2138
+ throw new SyntaxError(`The "${l}" subprotocol is duplicated`);
2139
+ e.add(l), t = r = -1;
2140
+ } else
2141
+ throw new SyntaxError(`Unexpected character at index ${i}`);
2142
+ }
2143
+ if (t === -1 || r !== -1)
2144
+ throw new SyntaxError("Unexpected end of input");
2145
+ const n = s.slice(t, i);
2146
+ if (e.has(n))
2147
+ throw new SyntaxError(`The "${n}" subprotocol is duplicated`);
2148
+ return e.add(n), e;
2149
+ }
2150
+ var Rs = { parse: Ps };
2151
+ const Us = S, ie = S, { createHash: Bs } = S, qe = nt, N = oe, $s = Rs, Ms = ft, { GUID: Is, kWebSocket: Ds } = U, Ws = /^[+/0-9A-Za-z]{22}==$/, Ke = 0, Xe = 1, pt = 2;
2152
+ class As extends Us {
2153
+ /**
2154
+ * Create a `WebSocketServer` instance.
2155
+ *
2156
+ * @param {Object} options Configuration options
2157
+ * @param {Number} [options.backlog=511] The maximum length of the queue of
2158
+ * pending connections
2159
+ * @param {Boolean} [options.clientTracking=true] Specifies whether or not to
2160
+ * track clients
2161
+ * @param {Function} [options.handleProtocols] A hook to handle protocols
2162
+ * @param {String} [options.host] The hostname where to bind the server
2163
+ * @param {Number} [options.maxPayload=104857600] The maximum allowed message
2164
+ * size
2165
+ * @param {Boolean} [options.noServer=false] Enable no server mode
2166
+ * @param {String} [options.path] Accept only connections matching this path
2167
+ * @param {(Boolean|Object)} [options.perMessageDeflate=false] Enable/disable
2168
+ * permessage-deflate
2169
+ * @param {Number} [options.port] The port where to bind the server
2170
+ * @param {(http.Server|https.Server)} [options.server] A pre-created HTTP/S
2171
+ * server to use
2172
+ * @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
2173
+ * not to skip UTF-8 validation for text and close messages
2174
+ * @param {Function} [options.verifyClient] A hook to reject connections
2175
+ * @param {Function} [options.WebSocket=WebSocket] Specifies the `WebSocket`
2176
+ * class to use. It must be the `WebSocket` class or class that extends it
2177
+ * @param {Function} [callback] A listener for the `listening` event
2178
+ */
2179
+ constructor(e, t) {
2180
+ if (super(), e = {
2181
+ maxPayload: 100 * 1024 * 1024,
2182
+ skipUTF8Validation: !1,
2183
+ perMessageDeflate: !1,
2184
+ handleProtocols: null,
2185
+ clientTracking: !0,
2186
+ verifyClient: null,
2187
+ noServer: !1,
2188
+ backlog: null,
2189
+ // use default (511 as implemented in net.js)
2190
+ server: null,
2191
+ host: null,
2192
+ path: null,
2193
+ port: null,
2194
+ WebSocket: Ms,
2195
+ ...e
2196
+ }, e.port == null && !e.server && !e.noServer || e.port != null && (e.server || e.noServer) || e.server && e.noServer)
2197
+ throw new TypeError(
2198
+ 'One and only one of the "port", "server", or "noServer" options must be specified'
2199
+ );
2200
+ if (e.port != null ? (this._server = ie.createServer((r, i) => {
2201
+ const n = ie.STATUS_CODES[426];
2202
+ i.writeHead(426, {
2203
+ "Content-Length": n.length,
2204
+ "Content-Type": "text/plain"
2205
+ }), i.end(n);
2206
+ }), this._server.listen(
2207
+ e.port,
2208
+ e.host,
2209
+ e.backlog,
2210
+ t
2211
+ )) : e.server && (this._server = e.server), this._server) {
2212
+ const r = this.emit.bind(this, "connection");
2213
+ this._removeListeners = js(this._server, {
2214
+ listening: this.emit.bind(this, "listening"),
2215
+ error: this.emit.bind(this, "error"),
2216
+ upgrade: (i, n, o) => {
2217
+ this.handleUpgrade(i, n, o, r);
2218
+ }
2219
+ });
2220
+ }
2221
+ e.perMessageDeflate === !0 && (e.perMessageDeflate = {}), e.clientTracking && (this.clients = /* @__PURE__ */ new Set(), this._shouldEmitClose = !1), this.options = e, this._state = Ke;
2222
+ }
2223
+ /**
2224
+ * Returns the bound address, the address family name, and port of the server
2225
+ * as reported by the operating system if listening on an IP socket.
2226
+ * If the server is listening on a pipe or UNIX domain socket, the name is
2227
+ * returned as a string.
2228
+ *
2229
+ * @return {(Object|String|null)} The address of the server
2230
+ * @public
2231
+ */
2232
+ address() {
2233
+ if (this.options.noServer)
2234
+ throw new Error('The server is operating in "noServer" mode');
2235
+ return this._server ? this._server.address() : null;
2236
+ }
2237
+ /**
2238
+ * Stop the server from accepting new connections and emit the `'close'` event
2239
+ * when all existing connections are closed.
2240
+ *
2241
+ * @param {Function} [cb] A one-time listener for the `'close'` event
2242
+ * @public
2243
+ */
2244
+ close(e) {
2245
+ if (this._state === pt) {
2246
+ e && this.once("close", () => {
2247
+ e(new Error("The server is not running"));
2248
+ }), process.nextTick(G, this);
2249
+ return;
2250
+ }
2251
+ if (e && this.once("close", e), this._state !== Xe)
2252
+ if (this._state = Xe, this.options.noServer || this.options.server)
2253
+ this._server && (this._removeListeners(), this._removeListeners = this._server = null), this.clients ? this.clients.size ? this._shouldEmitClose = !0 : process.nextTick(G, this) : process.nextTick(G, this);
2254
+ else {
2255
+ const t = this._server;
2256
+ this._removeListeners(), this._removeListeners = this._server = null, t.close(() => {
2257
+ G(this);
2258
+ });
2259
+ }
2260
+ }
2261
+ /**
2262
+ * See if a given request should be handled by this server instance.
2263
+ *
2264
+ * @param {http.IncomingMessage} req Request object to inspect
2265
+ * @return {Boolean} `true` if the request is valid, else `false`
2266
+ * @public
2267
+ */
2268
+ shouldHandle(e) {
2269
+ if (this.options.path) {
2270
+ const t = e.url.indexOf("?");
2271
+ if ((t !== -1 ? e.url.slice(0, t) : e.url) !== this.options.path)
2272
+ return !1;
2273
+ }
2274
+ return !0;
2275
+ }
2276
+ /**
2277
+ * Handle a HTTP Upgrade request.
2278
+ *
2279
+ * @param {http.IncomingMessage} req The request object
2280
+ * @param {(net.Socket|tls.Socket)} socket The network socket between the
2281
+ * server and client
2282
+ * @param {Buffer} head The first packet of the upgraded stream
2283
+ * @param {Function} cb Callback
2284
+ * @public
2285
+ */
2286
+ handleUpgrade(e, t, r, i) {
2287
+ t.on("error", Ze);
2288
+ const n = e.headers["sec-websocket-key"], o = +e.headers["sec-websocket-version"];
2289
+ if (e.method !== "GET") {
2290
+ R(this, e, t, 405, "Invalid HTTP method");
2291
+ return;
2292
+ }
2293
+ if (e.headers.upgrade.toLowerCase() !== "websocket") {
2294
+ R(this, e, t, 400, "Invalid Upgrade header");
2295
+ return;
2296
+ }
2297
+ if (!n || !Ws.test(n)) {
2298
+ R(this, e, t, 400, "Missing or invalid Sec-WebSocket-Key header");
2299
+ return;
2300
+ }
2301
+ if (o !== 8 && o !== 13) {
2302
+ R(this, e, t, 400, "Missing or invalid Sec-WebSocket-Version header");
2303
+ return;
2304
+ }
2305
+ if (!this.shouldHandle(e)) {
2306
+ H(t, 400);
2307
+ return;
2308
+ }
2309
+ const l = e.headers["sec-websocket-protocol"];
2310
+ let f = /* @__PURE__ */ new Set();
2311
+ if (l !== void 0)
2312
+ try {
2313
+ f = $s.parse(l);
2314
+ } catch {
2315
+ R(this, e, t, 400, "Invalid Sec-WebSocket-Protocol header");
2316
+ return;
2317
+ }
2318
+ const a = e.headers["sec-websocket-extensions"], c = {};
2319
+ if (this.options.perMessageDeflate && a !== void 0) {
2320
+ const h = new N(
2321
+ this.options.perMessageDeflate,
2322
+ !0,
2323
+ this.options.maxPayload
2324
+ );
2325
+ try {
2326
+ const p = qe.parse(a);
2327
+ p[N.extensionName] && (h.accept(p[N.extensionName]), c[N.extensionName] = h);
2328
+ } catch {
2329
+ R(this, e, t, 400, "Invalid or unacceptable Sec-WebSocket-Extensions header");
2330
+ return;
2331
+ }
2332
+ }
2333
+ if (this.options.verifyClient) {
2334
+ const h = {
2335
+ origin: e.headers[`${o === 8 ? "sec-websocket-origin" : "origin"}`],
2336
+ secure: !!(e.socket.authorized || e.socket.encrypted),
2337
+ req: e
2338
+ };
2339
+ if (this.options.verifyClient.length === 2) {
2340
+ this.options.verifyClient(h, (p, v, _, u) => {
2341
+ if (!p)
2342
+ return H(t, v || 401, _, u);
2343
+ this.completeUpgrade(
2344
+ c,
2345
+ n,
2346
+ f,
2347
+ e,
2348
+ t,
2349
+ r,
2350
+ i
2351
+ );
2352
+ });
2353
+ return;
2354
+ }
2355
+ if (!this.options.verifyClient(h))
2356
+ return H(t, 401);
2357
+ }
2358
+ this.completeUpgrade(c, n, f, e, t, r, i);
2359
+ }
2360
+ /**
2361
+ * Upgrade the connection to WebSocket.
2362
+ *
2363
+ * @param {Object} extensions The accepted extensions
2364
+ * @param {String} key The value of the `Sec-WebSocket-Key` header
2365
+ * @param {Set} protocols The subprotocols
2366
+ * @param {http.IncomingMessage} req The request object
2367
+ * @param {(net.Socket|tls.Socket)} socket The network socket between the
2368
+ * server and client
2369
+ * @param {Buffer} head The first packet of the upgraded stream
2370
+ * @param {Function} cb Callback
2371
+ * @throws {Error} If called more than once with the same socket
2372
+ * @private
2373
+ */
2374
+ completeUpgrade(e, t, r, i, n, o, l) {
2375
+ if (!n.readable || !n.writable)
2376
+ return n.destroy();
2377
+ if (n[Ds])
2378
+ throw new Error(
2379
+ "server.handleUpgrade() was called more than once with the same socket, possibly due to a misconfiguration"
2380
+ );
2381
+ if (this._state > Ke)
2382
+ return H(n, 503);
2383
+ const a = [
2384
+ "HTTP/1.1 101 Switching Protocols",
2385
+ "Upgrade: websocket",
2386
+ "Connection: Upgrade",
2387
+ `Sec-WebSocket-Accept: ${Bs("sha1").update(t + Is).digest("base64")}`
2388
+ ], c = new this.options.WebSocket(null);
2389
+ if (r.size) {
2390
+ const h = this.options.handleProtocols ? this.options.handleProtocols(r, i) : r.values().next().value;
2391
+ h && (a.push(`Sec-WebSocket-Protocol: ${h}`), c._protocol = h);
2392
+ }
2393
+ if (e[N.extensionName]) {
2394
+ const h = e[N.extensionName].params, p = qe.format({
2395
+ [N.extensionName]: [h]
2396
+ });
2397
+ a.push(`Sec-WebSocket-Extensions: ${p}`), c._extensions = e;
2398
+ }
2399
+ this.emit("headers", a, i), n.write(a.concat(`\r
2400
+ `).join(`\r
2401
+ `)), n.removeListener("error", Ze), c.setSocket(n, o, {
2402
+ maxPayload: this.options.maxPayload,
2403
+ skipUTF8Validation: this.options.skipUTF8Validation
2404
+ }), this.clients && (this.clients.add(c), c.on("close", () => {
2405
+ this.clients.delete(c), this._shouldEmitClose && !this.clients.size && process.nextTick(G, this);
2406
+ })), l(c, i);
2407
+ }
2408
+ }
2409
+ var Fs = As;
2410
+ function js(s, e) {
2411
+ for (const t of Object.keys(e))
2412
+ s.on(t, e[t]);
2413
+ return function() {
2414
+ for (const r of Object.keys(e))
2415
+ s.removeListener(r, e[r]);
2416
+ };
2417
+ }
2418
+ function G(s) {
2419
+ s._state = pt, s.emit("close");
2420
+ }
2421
+ function Ze() {
2422
+ this.destroy();
2423
+ }
2424
+ function H(s, e, t, r) {
2425
+ t = t || ie.STATUS_CODES[e], r = {
2426
+ Connection: "close",
2427
+ "Content-Type": "text/html",
2428
+ "Content-Length": Buffer.byteLength(t),
2429
+ ...r
2430
+ }, s.once("finish", s.destroy), s.end(
2431
+ `HTTP/1.1 ${e} ${ie.STATUS_CODES[e]}\r
2432
+ ` + Object.keys(r).map((i) => `${i}: ${r[i]}`).join(`\r
2433
+ `) + `\r
2434
+ \r
2435
+ ` + t
2436
+ );
2437
+ }
2438
+ function R(s, e, t, r, i) {
2439
+ if (s.listenerCount("wsClientError")) {
2440
+ const n = new Error(i);
2441
+ Error.captureStackTrace(n, R), s.emit("wsClientError", n, t, e);
2442
+ } else
2443
+ H(t, r, i);
2444
+ }
2445
+ const Zs = /* @__PURE__ */ z(Fs);
2446
+ export {
2447
+ qs as Receiver,
2448
+ Ks as Sender,
2449
+ Xs as WebSocket,
2450
+ Zs as WebSocketServer,
2451
+ Vs as createWebSocketStream,
2452
+ Xs as default
2453
+ };
gradio_dualvision/gradio_patches/templates/example/index.js ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const {
2
+ SvelteComponent: y,
3
+ append: m,
4
+ attr: n,
5
+ detach: b,
6
+ element: u,
7
+ init: k,
8
+ insert: p,
9
+ noop: w,
10
+ safe_not_equal: h,
11
+ space: o,
12
+ src_url_equal: v,
13
+ toggle_class: f
14
+ } = window.__gradio__svelte__internal;
15
+ function q(a) {
16
+ let e, l, _, g, i, c, s, d;
17
+ return {
18
+ c() {
19
+ e = u("div"), l = u("img"), g = o(), i = u("img"), s = o(), d = u("span"), v(l.src, _ = /*samples_dir*/
20
+ a[1] + /*value*/
21
+ a[0][0]) || n(l, "src", _), n(l, "class", "svelte-l4wpk0"), v(i.src, c = /*samples_dir*/
22
+ a[1] + /*value*/
23
+ a[0][1]) || n(i, "src", c), n(i, "class", "svelte-l4wpk0"), n(d, "class", "svelte-l4wpk0"), n(e, "class", "wrap svelte-l4wpk0"), f(
24
+ e,
25
+ "table",
26
+ /*type*/
27
+ a[2] === "table"
28
+ ), f(
29
+ e,
30
+ "gallery",
31
+ /*type*/
32
+ a[2] === "gallery"
33
+ ), f(
34
+ e,
35
+ "selected",
36
+ /*selected*/
37
+ a[3]
38
+ );
39
+ },
40
+ m(t, r) {
41
+ p(t, e, r), m(e, l), m(e, g), m(e, i), m(e, s), m(e, d);
42
+ },
43
+ p(t, [r]) {
44
+ r & /*samples_dir, value*/
45
+ 3 && !v(l.src, _ = /*samples_dir*/
46
+ t[1] + /*value*/
47
+ t[0][0]) && n(l, "src", _), r & /*samples_dir, value*/
48
+ 3 && !v(i.src, c = /*samples_dir*/
49
+ t[1] + /*value*/
50
+ t[0][1]) && n(i, "src", c), r & /*type*/
51
+ 4 && f(
52
+ e,
53
+ "table",
54
+ /*type*/
55
+ t[2] === "table"
56
+ ), r & /*type*/
57
+ 4 && f(
58
+ e,
59
+ "gallery",
60
+ /*type*/
61
+ t[2] === "gallery"
62
+ ), r & /*selected*/
63
+ 8 && f(
64
+ e,
65
+ "selected",
66
+ /*selected*/
67
+ t[3]
68
+ );
69
+ },
70
+ i: w,
71
+ o: w,
72
+ d(t) {
73
+ t && b(e);
74
+ }
75
+ };
76
+ }
77
+ function I(a, e, l) {
78
+ let { value: _ } = e, { samples_dir: g } = e, { type: i } = e, { selected: c = !1 } = e;
79
+ return a.$$set = (s) => {
80
+ "value" in s && l(0, _ = s.value), "samples_dir" in s && l(1, g = s.samples_dir), "type" in s && l(2, i = s.type), "selected" in s && l(3, c = s.selected);
81
+ }, [_, g, i, c];
82
+ }
83
+ class C extends y {
84
+ constructor(e) {
85
+ super(), k(this, e, I, q, h, {
86
+ value: 0,
87
+ samples_dir: 1,
88
+ type: 2,
89
+ selected: 3
90
+ });
91
+ }
92
+ }
93
+ export {
94
+ C as default
95
+ };
gradio_dualvision/gradio_patches/templates/example/style.css ADDED
@@ -0,0 +1 @@
 
 
1
+ .wrap.svelte-l4wpk0.svelte-l4wpk0{position:relative;height:var(--size-64);width:var(--size-40);overflow:hidden;border-radius:var(--radius-lg)}img.svelte-l4wpk0.svelte-l4wpk0{height:var(--size-64);width:var(--size-40);position:absolute;object-fit:cover}.wrap.selected.svelte-l4wpk0.svelte-l4wpk0{border-color:var(--color-accent)}.wrap.svelte-l4wpk0 img.svelte-l4wpk0:first-child{clip-path:inset(0 50% 0 0%)}.wrap.svelte-l4wpk0 img.svelte-l4wpk0:nth-of-type(2){clip-path:inset(0 0 0 50%)}span.svelte-l4wpk0.svelte-l4wpk0{position:absolute;top:0;left:calc(50% - .75px);height:var(--size-64);width:1.5px;background:var(--border-color-primary)}.table.svelte-l4wpk0.svelte-l4wpk0{margin:0 auto;border:2px solid var(--border-color-primary);border-radius:var(--radius-lg)}.gallery.svelte-l4wpk0.svelte-l4wpk0{border:2px solid var(--border-color-primary);object-fit:cover}
gradio_dualvision/version.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+
25
+ __version__ = "0.1.0"
model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .hyper import Hypernetwork
2
+ from .thera import build_thera
model/convnext.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flax.linen as nn
2
+ from jaxtyping import Array, ArrayLike
3
+
4
+
5
+ class ConvNeXtBlock(nn.Module):
6
+ """ConvNext block. See Fig.4 in "A ConvNet for the 2020s" by Liu et al.
7
+
8
+ https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.pdf
9
+ """
10
+ n_dims: int = 64
11
+ kernel_size: int = 3 # 7 in the paper's version
12
+ group_features: bool = False
13
+
14
+ def setup(self) -> None:
15
+ self.residual = nn.Sequential([
16
+ nn.Conv(self.n_dims, kernel_size=(self.kernel_size, self.kernel_size), use_bias=False,
17
+ feature_group_count=self.n_dims if self.group_features else 1),
18
+ nn.LayerNorm(),
19
+ nn.Conv(4 * self.n_dims, kernel_size=(1, 1)),
20
+ nn.gelu,
21
+ nn.Conv(self.n_dims, kernel_size=(1, 1)),
22
+ ])
23
+
24
+ def __call__(self, x: ArrayLike) -> Array:
25
+ return x + self.residual(x)
26
+
27
+
28
+ class Projection(nn.Module):
29
+ n_dims: int
30
+
31
+ @nn.compact
32
+ def __call__(self, x: ArrayLike) -> Array:
33
+ x = nn.LayerNorm()(x)
34
+ x = nn.Conv(self.n_dims, (1, 1))(x)
35
+ return x
36
+
37
+
38
+ class ConvNeXt(nn.Module):
39
+ block_defs: list[tuple]
40
+
41
+ def setup(self) -> None:
42
+ layers = []
43
+ current_size = self.block_defs[0][0]
44
+ for block_def in self.block_defs:
45
+ if block_def[0] != current_size:
46
+ layers.append(Projection(block_def[0]))
47
+ layers.append(ConvNeXtBlock(*block_def))
48
+ current_size = block_def[0]
49
+ self.layers = layers
50
+
51
+ def __call__(self, x: ArrayLike, _: bool) -> Array:
52
+ for layer in self.layers:
53
+ x = layer(x)
54
+ return x
55
+
model/edsr.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from https://github.com/isaaccorley/jax-enhance
2
+
3
+ from functools import partial
4
+ from typing import Any, Sequence, Callable
5
+
6
+ import jax.numpy as jnp
7
+ import flax.linen as nn
8
+ from flax.core.frozen_dict import freeze
9
+ import einops
10
+
11
+
12
+ class PixelShuffle(nn.Module):
13
+ scale_factor: int
14
+
15
+ def setup(self):
16
+ self.layer = partial(
17
+ einops.rearrange,
18
+ pattern="b h w (c h2 w2) -> b (h h2) (w w2) c",
19
+ h2=self.scale_factor,
20
+ w2=self.scale_factor
21
+ )
22
+
23
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
24
+ return self.layer(x)
25
+
26
+
27
+ class ResidualBlock(nn.Module):
28
+ channels: int
29
+ kernel_size: Sequence[int]
30
+ res_scale: float
31
+ activation: Callable
32
+ dtype: Any = jnp.float32
33
+
34
+ def setup(self):
35
+ self.body = nn.Sequential([
36
+ nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
37
+ self.activation,
38
+ nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
39
+ ])
40
+
41
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
42
+ return x + self.body(x)
43
+
44
+
45
+ class UpsampleBlock(nn.Module):
46
+ num_upsamples: int
47
+ channels: int
48
+ kernel_size: Sequence[int]
49
+ dtype: Any = jnp.float32
50
+
51
+ def setup(self):
52
+ layers = []
53
+ for _ in range(self.num_upsamples):
54
+ layers.extend([
55
+ nn.Conv(features=self.channels * 2 ** 2, kernel_size=self.kernel_size, dtype=self.dtype),
56
+ PixelShuffle(scale_factor=2),
57
+ ])
58
+ self.layers = layers
59
+
60
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
61
+ for layer in self.layers:
62
+ x = layer(x)
63
+ return x
64
+
65
+
66
+ class EDSR(nn.Module):
67
+ """Enhanced Deep Residual Networks for Single Image Super-Resolution https://arxiv.org/pdf/1707.02921v1.pdf"""
68
+ scale_factor: int
69
+ channels: int = 3
70
+ num_blocks: int = 32
71
+ num_feats: int = 256
72
+ dtype: Any = jnp.float32
73
+
74
+ def setup(self):
75
+ # pre res blocks layer
76
+ self.head = nn.Sequential([nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype)])
77
+
78
+ # res blocks
79
+ res_blocks = [
80
+ ResidualBlock(channels=self.num_feats, kernel_size=(3, 3), res_scale=0.1, activation=nn.relu, dtype=self.dtype)
81
+ for i in range(self.num_blocks)
82
+ ]
83
+ res_blocks.append(nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype))
84
+ self.body = nn.Sequential(res_blocks)
85
+
86
+ def __call__(self, x: jnp.ndarray, _=None) -> jnp.ndarray:
87
+ x = self.head(x)
88
+ x = x + self.body(x)
89
+ return x
90
+
91
+
92
+ def convert_edsr_checkpoint(torch_dict, no_upsampling=True):
93
+ def convert(in_dict):
94
+ top_keys = set([k.split('.')[0] for k in in_dict.keys()])
95
+ leaves = set([k for k in in_dict.keys() if '.' not in k])
96
+
97
+ # convert leaves
98
+ out_dict = {}
99
+ for l in leaves:
100
+ if l == 'weight':
101
+ out_dict['kernel'] = jnp.asarray(in_dict[l]).transpose((2, 3, 1, 0))
102
+ elif l == 'bias':
103
+ out_dict[l] = jnp.asarray(in_dict[l])
104
+ else:
105
+ out_dict[l] = in_dict[l]
106
+
107
+ for top_key in top_keys.difference(leaves):
108
+ new_top_key = 'layers_' + top_key if top_key.isdigit() else top_key
109
+ out_dict[new_top_key] = convert(
110
+ {k[len(top_key) + 1:]: v for k, v in in_dict.items() if k.startswith(top_key)})
111
+ return out_dict
112
+
113
+ converted = convert(torch_dict)
114
+
115
+ # remove unwanted keys
116
+ if no_upsampling:
117
+ del converted['tail']
118
+
119
+ for k in ('add_mean', 'sub_mean'):
120
+ del converted[k]
121
+
122
+ return freeze(converted)
model/hyper.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import flax.linen as nn
6
+ from jaxtyping import Array, ArrayLike, PyTreeDef
7
+ import numpy as np
8
+
9
+ from utils import interpolate_grid
10
+
11
+
12
+ class Hypernetwork(nn.Module):
13
+ encoder: nn.Module
14
+ refine: nn.Module
15
+ output_params_shape: list[tuple] # e.g. [(16,), (32, 32), ...]
16
+ tree_def: PyTreeDef # used to reconstruct the parameter sets
17
+
18
+ def setup(self):
19
+ # one layer 1x1 conv to calculate field params, as in SIREN paper
20
+ output_size = sum(math.prod(s) for s in self.output_params_shape)
21
+ self.out_conv = nn.Conv(output_size, kernel_size=(1, 1), use_bias=True)
22
+
23
+ def get_encoding(self, source: ArrayLike, training=False) -> Array:
24
+ """Convenience method for whole-image evaluation"""
25
+ return self.refine(self.encoder(source, training), training)
26
+
27
+ def get_params_at_coords(self, encoding: ArrayLike, coords: ArrayLike) -> Array:
28
+ encoding = interpolate_grid(coords, encoding)
29
+ phi_params = self.out_conv(encoding)
30
+
31
+ # reshape to output params shape
32
+ phi_params = jnp.split(
33
+ phi_params, np.cumsum([math.prod(s) for s in self.output_params_shape[:-1]]), axis=-1)
34
+ phi_params = [jnp.reshape(p, p.shape[:-1] + s) for p, s in
35
+ zip(phi_params, self.output_params_shape)]
36
+
37
+ return jax.tree_util.tree_unflatten(self.tree_def, phi_params)
38
+
39
+ def __call__(self, source: ArrayLike, target_coords: ArrayLike, training=False) -> Array:
40
+ encoding = self.get_encoding(source, training)
41
+ return self.get_params_at_coords(encoding, target_coords)
model/init.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from jaxtyping import Array
6
+
7
+
8
+ def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
9
+ def init(key, shape, dtype=dtype) -> Array:
10
+ return jax.random.uniform(key, shape, dtype=dtype, minval=a, maxval=b)
11
+ return init
12
+
13
+
14
+ def linear_up(scale: float) -> Callable:
15
+ def init(key, shape, dtype=jnp.float32) -> Array:
16
+ assert shape[-2] == 2
17
+ keys = jax.random.split(key, 2)
18
+ norm = jnp.pi * scale * (
19
+ jax.random.uniform(keys[0], shape=(1, shape[-1])) ** .5)
20
+ theta = 2 * jnp.pi * jax.random.uniform(keys[1], shape=(1, shape[-1]))
21
+ x = norm * jnp.cos(theta)
22
+ y = norm * jnp.sin(theta)
23
+ return jnp.concatenate([x, y], axis=-2).astype(dtype)
24
+ return init
model/rdn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Residual Dense Network for Image Super-Resolution
2
+ # https://arxiv.org/abs/1802.08797
3
+ # modified from: https://github.com/thstkdgus35/EDSR-PyTorch
4
+
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+
8
+
9
+ class RDB_Conv(nn.Module):
10
+ growRate: int
11
+ kSize: int = 3
12
+
13
+ @nn.compact
14
+ def __call__(self, x):
15
+ out = nn.Sequential([
16
+ nn.Conv(self.growRate, (self.kSize, self.kSize), padding=(self.kSize-1)//2),
17
+ nn.activation.relu
18
+ ])(x)
19
+ return jnp.concatenate((x, out), -1)
20
+
21
+
22
+ class RDB(nn.Module):
23
+ growRate0: int
24
+ growRate: int
25
+ nConvLayers: int
26
+
27
+ @nn.compact
28
+ def __call__(self, x):
29
+ res = x
30
+
31
+ for c in range(self.nConvLayers):
32
+ x = RDB_Conv(self.growRate)(x)
33
+
34
+ x = nn.Conv(self.growRate0, (1, 1))(x)
35
+
36
+ return x + res
37
+
38
+
39
+ class RDN(nn.Module):
40
+ G0: int = 64
41
+ RDNkSize: int = 3
42
+ RDNconfig: str = 'B'
43
+ scale: int = 2
44
+ n_colors: int = 3
45
+
46
+ @nn.compact
47
+ def __call__(self, x, _=None):
48
+ D, C, G = {
49
+ 'A': (20, 6, 32),
50
+ 'B': (16, 8, 64),
51
+ }[self.RDNconfig]
52
+
53
+ # Shallow feature extraction
54
+ f_1 = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(x)
55
+ x = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(f_1)
56
+
57
+ # Redidual dense blocks and dense feature fusion
58
+ RDBs_out = []
59
+ for i in range(D):
60
+ x = RDB(self.G0, G, C)(x)
61
+ RDBs_out.append(x)
62
+
63
+ x = jnp.concatenate(RDBs_out, -1)
64
+
65
+ # Global Feature Fusion
66
+ x = nn.Sequential([
67
+ nn.Conv(self.G0, (1, 1)),
68
+ nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))
69
+ ])(x)
70
+
71
+ x = x + f_1
72
+ return x
model/swin_ir.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable, Optional, Iterable
3
+
4
+ import numpy as np
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import flax.linen as nn
8
+ from jaxtyping import Array
9
+
10
+
11
+ def trunc_normal(mean=0., std=1., a=-2., b=2., dtype=jnp.float32) -> Callable:
12
+ """Truncated normal initialization function"""
13
+
14
+ def init(key, shape, dtype=dtype) -> Array:
15
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
16
+ def norm_cdf(x):
17
+ # Computes standard normal cumulative distribution function
18
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
19
+
20
+ l = norm_cdf((a - mean) / std)
21
+ u = norm_cdf((b - mean) / std)
22
+ out = jax.random.uniform(key, shape, dtype=dtype, minval=2 * l - 1, maxval=2 * u - 1)
23
+ out = jax.scipy.special.erfinv(out) * std * math.sqrt(2.) + mean
24
+ return jnp.clip(out, a, b)
25
+
26
+ return init
27
+
28
+
29
+ def Dense(features, use_bias=True, kernel_init=trunc_normal(std=.02), bias_init=nn.initializers.zeros):
30
+ return nn.Dense(features, use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init)
31
+
32
+
33
+ def LayerNorm():
34
+ """torch LayerNorm uses larger epsilon by default"""
35
+ return nn.LayerNorm(epsilon=1e-05)
36
+
37
+
38
+ class Mlp(nn.Module):
39
+
40
+ in_features: int
41
+ hidden_features: int = None
42
+ out_features: int = None
43
+ act_layer: Callable = nn.gelu
44
+ drop: float = 0.0
45
+
46
+ @nn.compact
47
+ def __call__(self, x, training: bool):
48
+ x = nn.Dense(self.hidden_features or self.in_features)(x)
49
+ x = self.act_layer(x)
50
+ x = nn.Dropout(self.drop, deterministic=not training)(x)
51
+ x = nn.Dense(self.out_features or self.in_features)(x)
52
+ x = nn.Dropout(self.drop, deterministic=not training)(x)
53
+ return x
54
+
55
+
56
+ def window_partition(x, window_size: int):
57
+ """
58
+ Args:
59
+ x: (B, H, W, C)
60
+ window_size (int): window size
61
+
62
+ Returns:
63
+ windows: (num_windows*B, window_size, window_size, C)
64
+ """
65
+ B, H, W, C = x.shape
66
+ x = x.reshape((B, H // window_size, window_size, W // window_size, window_size, C))
67
+ windows = x.transpose((0, 1, 3, 2, 4, 5)).reshape((-1, window_size, window_size, C))
68
+ return windows
69
+
70
+
71
+ def window_reverse(windows, window_size: int, H: int, W: int):
72
+ """
73
+ Args:
74
+ windows: (num_windows*B, window_size, window_size, C)
75
+ window_size (int): Window size
76
+ H (int): Height of image
77
+ W (int): Width of image
78
+
79
+ Returns:
80
+ x: (B, H, W, C)
81
+ """
82
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
83
+ x = windows.reshape((B, H // window_size, W // window_size, window_size, window_size, -1))
84
+ x = x.transpose((0, 1, 3, 2, 4, 5)).reshape((B, H, W, -1))
85
+ return x
86
+
87
+
88
+ class DropPath(nn.Module):
89
+ """
90
+ Implementation referred from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
91
+ """
92
+
93
+ dropout_prob: float = 0.1
94
+ deterministic: Optional[bool] = None
95
+
96
+ @nn.compact
97
+ def __call__(self, input, training):
98
+ if not training:
99
+ return input
100
+ keep_prob = 1 - self.dropout_prob
101
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1)
102
+ rng = self.make_rng("dropout")
103
+ random_tensor = keep_prob + jax.random.uniform(rng, shape)
104
+ random_tensor = jnp.floor(random_tensor)
105
+ return jnp.divide(input, keep_prob) * random_tensor
106
+
107
+
108
+ class WindowAttention(nn.Module):
109
+ dim: int
110
+ window_size: Iterable[int]
111
+ num_heads: int
112
+ qkv_bias: bool = True
113
+ qk_scale: Optional[float] = None
114
+ att_drop: float = 0.0
115
+ proj_drop: float = 0.0
116
+
117
+ def make_rel_pos_index(self):
118
+ h_indices = np.arange(0, self.window_size[0])
119
+ w_indices = np.arange(0, self.window_size[1])
120
+ indices = np.stack(np.meshgrid(w_indices, h_indices, indexing="ij"))
121
+ flatten_indices = np.reshape(indices, (2, -1))
122
+ relative_indices = flatten_indices[:, :, None] - flatten_indices[:, None, :]
123
+ relative_indices = np.transpose(relative_indices, (1, 2, 0))
124
+ relative_indices[:, :, 0] += self.window_size[0] - 1
125
+ relative_indices[:, :, 1] += self.window_size[1] - 1
126
+ relative_indices[:, :, 0] *= 2 * self.window_size[1] - 1
127
+ relative_pos_index = np.sum(relative_indices, -1)
128
+ return relative_pos_index
129
+
130
+ @nn.compact
131
+ def __call__(self, inputs, mask, training):
132
+ rpbt = self.param(
133
+ "relative_position_bias_table",
134
+ trunc_normal(std=.02),
135
+ (
136
+ (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
137
+ self.num_heads,
138
+ ),
139
+ )
140
+
141
+ #relative_pos_index = self.variable(
142
+ # "variables", "relative_position_index", self.get_rel_pos_index
143
+ #)
144
+
145
+ batch, n, channels = inputs.shape
146
+ qkv = nn.Dense(self.dim * 3, use_bias=self.qkv_bias, name="qkv")(inputs)
147
+ qkv = qkv.reshape(batch, n, 3, self.num_heads, channels // self.num_heads)
148
+ qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
149
+ q, k, v = qkv[0], qkv[1], qkv[2]
150
+
151
+ scale = self.qk_scale or (self.dim // self.num_heads) ** -0.5
152
+ q = q * scale
153
+ att = q @ jnp.swapaxes(k, -2, -1)
154
+
155
+ rel_pos_bias = jnp.reshape(
156
+ rpbt[np.reshape(self.make_rel_pos_index(), (-1))],
157
+ (
158
+ self.window_size[0] * self.window_size[1],
159
+ self.window_size[0] * self.window_size[1],
160
+ -1,
161
+ ),
162
+ )
163
+ rel_pos_bias = jnp.transpose(rel_pos_bias, (2, 0, 1))
164
+ att += jnp.expand_dims(rel_pos_bias, 0)
165
+
166
+ if mask is not None:
167
+ att = jnp.reshape(
168
+ att, (batch // mask.shape[0], mask.shape[0], self.num_heads, n, n)
169
+ )
170
+ att = att + jnp.expand_dims(jnp.expand_dims(mask, 1), 0)
171
+ att = jnp.reshape(att, (-1, self.num_heads, n, n))
172
+ att = jax.nn.softmax(att)
173
+
174
+ else:
175
+ att = jax.nn.softmax(att)
176
+
177
+ att = nn.Dropout(self.att_drop)(att, deterministic=not training)
178
+
179
+ x = jnp.reshape(jnp.swapaxes(att @ v, 1, 2), (batch, n, channels))
180
+ x = nn.Dense(self.dim, name="proj")(x)
181
+ x = nn.Dropout(self.proj_drop)(x, deterministic=not training)
182
+ return x
183
+
184
+
185
+ class SwinTransformerBlock(nn.Module):
186
+
187
+ dim: int
188
+ input_resolution: tuple[int]
189
+ num_heads: int
190
+ window_size: int = 7
191
+ shift_size: int = 0
192
+ mlp_ratio: float = 4.
193
+ qkv_bias: bool = True
194
+ qk_scale: Optional[float] = None
195
+ drop: float = 0.
196
+ attn_drop: float = 0.
197
+ drop_path: float = 0.
198
+ act_layer: Callable = nn.activation.gelu
199
+ norm_layer: Callable = LayerNorm
200
+
201
+ @staticmethod
202
+ def make_att_mask(shift_size, window_size, height, width):
203
+ if shift_size > 0:
204
+ mask = jnp.zeros([1, height, width, 1])
205
+ h_slices = (
206
+ slice(0, -window_size),
207
+ slice(-window_size, -shift_size),
208
+ slice(-shift_size, None),
209
+ )
210
+ w_slices = (
211
+ slice(0, -window_size),
212
+ slice(-window_size, -shift_size),
213
+ slice(-shift_size, None),
214
+ )
215
+
216
+ count = 0
217
+ for h in h_slices:
218
+ for w in w_slices:
219
+ mask = mask.at[:, h, w, :].set(count)
220
+ count += 1
221
+
222
+ mask_windows = window_partition(mask, window_size)
223
+ mask_windows = jnp.reshape(mask_windows, (-1, window_size * window_size))
224
+ att_mask = jnp.expand_dims(mask_windows, 1) - jnp.expand_dims(mask_windows, 2)
225
+ att_mask = jnp.where(att_mask != 0.0, float(-100.0), att_mask)
226
+ att_mask = jnp.where(att_mask == 0.0, float(0.0), att_mask)
227
+ else:
228
+ att_mask = None
229
+
230
+ return att_mask
231
+
232
+ @nn.compact
233
+ def __call__(self, x, x_size, training):
234
+ H, W = x_size
235
+ B, L, C = x.shape
236
+
237
+ if min(self.input_resolution) <= self.window_size:
238
+ # if window size is larger than input resolution, we don't partition windows
239
+ self.shift_size = 0
240
+ self.window_size = min(self.input_resolution)
241
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
242
+
243
+ shortcut = x
244
+ x = self.norm_layer()(x)
245
+ x = x.reshape((B, H, W, C))
246
+
247
+ # cyclic shift
248
+ if self.shift_size > 0:
249
+ shifted_x = jnp.roll(x, (-self.shift_size, -self.shift_size), axis=(1, 2))
250
+ else:
251
+ shifted_x = x
252
+
253
+ # partition windows
254
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
255
+ x_windows = x_windows.reshape((-1, self.window_size * self.window_size, C)) # nW*B, window_size*window_size, C
256
+
257
+ #attn_mask = self.variable(
258
+ # "variables",
259
+ # "attn_mask",
260
+ # self.get_att_mask,
261
+ # self.shift_size,
262
+ # self.window_size,
263
+ # self.input_resolution[0],
264
+ # self.input_resolution[1]
265
+ #)
266
+
267
+ attn_mask = self.make_att_mask(self.shift_size, self.window_size, *self.input_resolution)
268
+
269
+ attn = WindowAttention(self.dim, (self.window_size, self.window_size), self.num_heads,
270
+ self.qkv_bias, self.qk_scale, self.attn_drop, self.drop)
271
+ if self.input_resolution == x_size:
272
+ attn_windows = attn(x_windows, attn_mask, training) # nW*B, window_size*window_size, C
273
+ else:
274
+ # test time
275
+ assert not training
276
+ test_mask = self.make_att_mask(self.shift_size, self.window_size, *x_size)
277
+ attn_windows = attn(x_windows, test_mask, training=False)
278
+
279
+ # merge windows
280
+ attn_windows = attn_windows.reshape((-1, self.window_size, self.window_size, C))
281
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
282
+
283
+ # reverse cyclic shift
284
+ if self.shift_size > 0:
285
+ x = jnp.roll(shifted_x, (self.shift_size, self.shift_size), axis=(1, 2))
286
+ else:
287
+ x = shifted_x
288
+
289
+ x = x.reshape((B, H * W, C))
290
+
291
+ # FFN
292
+ x = shortcut + DropPath(self.drop_path)(x, training)
293
+
294
+ norm = self.norm_layer()(x)
295
+ mlp = Mlp(in_features=self.dim, hidden_features=int(self.dim * self.mlp_ratio),
296
+ act_layer=self.act_layer, drop=self.drop)(norm, training)
297
+ x = x + DropPath(self.drop_path)(mlp, training)
298
+
299
+ return x
300
+
301
+
302
+ class PatchMerging(nn.Module):
303
+ inp_res: Iterable[int]
304
+ dim: int
305
+ norm_layer: Callable = LayerNorm
306
+
307
+ @nn.compact
308
+ def __call__(self, inputs):
309
+ batch, n, channels = inputs.shape
310
+ height, width = self.inp_res[0], self.inp_res[1]
311
+ x = jnp.reshape(inputs, (batch, height, width, channels))
312
+
313
+ x0 = x[:, 0::2, 0::2, :]
314
+ x1 = x[:, 1::2, 0::2, :]
315
+ x2 = x[:, 0::2, 1::2, :]
316
+ x3 = x[:, 1::2, 1::2, :]
317
+
318
+ x = jnp.concatenate([x0, x1, x2, x3], axis=-1)
319
+ x = jnp.reshape(x, (batch, -1, 4 * channels))
320
+ x = self.norm_layer()(x)
321
+ x = nn.Dense(2 * self.dim, use_bias=False)(x)
322
+ return x
323
+
324
+
325
+ class BasicLayer(nn.Module):
326
+
327
+ dim: int
328
+ input_resolution: int
329
+ depth: int
330
+ num_heads: int
331
+ window_size: int
332
+ mlp_ratio: float = 4.
333
+ qkv_bias: bool = True
334
+ qk_scale: Optional[float] = None
335
+ drop: float = 0.
336
+ attn_drop: float = 0.
337
+ drop_path: float = 0.
338
+ norm_layer: Callable = LayerNorm
339
+ downsample: Optional[Callable] = None
340
+
341
+ @nn.compact
342
+ def __call__(self, x, x_size, training):
343
+ for i in range(self.depth):
344
+ x = SwinTransformerBlock(
345
+ self.dim,
346
+ self.input_resolution,
347
+ self.num_heads,
348
+ self.window_size,
349
+ 0 if (i % 2 == 0) else self.window_size // 2,
350
+ self.mlp_ratio,
351
+ self.qkv_bias,
352
+ self.qk_scale,
353
+ self.drop,
354
+ self.attn_drop,
355
+ self.drop_path[i] if isinstance(self.drop_path, (list, tuple)) else self.drop_path,
356
+ norm_layer=self.norm_layer
357
+ )(x, x_size, training)
358
+
359
+ if self.downsample is not None:
360
+ x = self.downsample(self.input_resolution, dim=self.dim, norm_layer=self.norm_layer)(x)
361
+
362
+ return x
363
+
364
+
365
+ class RSTB(nn.Module):
366
+
367
+ dim: int
368
+ input_resolution: int
369
+ depth: int
370
+ num_heads: int
371
+ window_size: int
372
+ mlp_ratio: float = 4.
373
+ qkv_bias: bool = True
374
+ qk_scale: Optional[float] = None
375
+ drop: float = 0.
376
+ attn_drop: float = 0.
377
+ drop_path: float = 0.
378
+ norm_layer: Callable = LayerNorm
379
+ downsample: Optional[Callable] = None
380
+ img_size: int = 224,
381
+ patch_size: int = 4,
382
+ resi_connection: str = '1conv'
383
+
384
+ @nn.compact
385
+ def __call__(self, x, x_size, training):
386
+ res = x
387
+ x = BasicLayer(dim=self.dim,
388
+ input_resolution=self.input_resolution,
389
+ depth=self.depth,
390
+ num_heads=self.num_heads,
391
+ window_size=self.window_size,
392
+ mlp_ratio=self.mlp_ratio,
393
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
394
+ drop=self.drop, attn_drop=self.attn_drop,
395
+ drop_path=self.drop_path,
396
+ norm_layer=self.norm_layer,
397
+ downsample=self.downsample)(x, x_size, training)
398
+
399
+ x = PatchUnEmbed(embed_dim=self.dim)(x, x_size)
400
+
401
+ # resi_connection == '1conv':
402
+ x = nn.Conv(self.dim, (3, 3))(x)
403
+
404
+ x = PatchEmbed()(x)
405
+
406
+ return x + res
407
+
408
+
409
+ class PatchEmbed(nn.Module):
410
+ norm_layer: Optional[Callable] = None
411
+
412
+ @nn.compact
413
+ def __call__(self, x):
414
+ x = x.reshape((x.shape[0], -1, x.shape[-1])) # B Ph Pw C -> B Ph*Pw C
415
+ if self.norm_layer is not None:
416
+ x = self.norm_layer()(x)
417
+ return x
418
+
419
+
420
+ class PatchUnEmbed(nn.Module):
421
+ embed_dim: int = 96
422
+
423
+ @nn.compact
424
+ def __call__(self, x, x_size):
425
+ B, HW, C = x.shape
426
+ x = x.reshape((B, x_size[0], x_size[1], self.embed_dim))
427
+ return x
428
+
429
+
430
+ class SwinIR(nn.Module):
431
+ r""" SwinIR JAX implementation
432
+ Args:
433
+ img_size (int | tuple(int)): Input image size. Default 64
434
+ patch_size (int | tuple(int)): Patch size. Default: 1
435
+ in_chans (int): Number of input image channels. Default: 3
436
+ embed_dim (int): Patch embedding dimension. Default: 96
437
+ depths (tuple(int)): Depth of each Swin Transformer layer.
438
+ num_heads (tuple(int)): Number of attention heads in different layers.
439
+ window_size (int): Window size. Default: 7
440
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
441
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
442
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
443
+ drop_rate (float): Dropout rate. Default: 0
444
+ attn_drop_rate (float): Attention dropout rate. Default: 0
445
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
446
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
447
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
448
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
449
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
450
+ img_range: Image range. 1. or 25I think5.
451
+ """
452
+
453
+ img_size: int = 48
454
+ patch_size: int = 1
455
+ in_chans: int = 3
456
+ embed_dim: int = 180
457
+ depths: tuple = (6, 6, 6, 6, 6, 6)
458
+ num_heads: tuple = (6, 6, 6, 6, 6, 6)
459
+ window_size: int = 8
460
+ mlp_ratio: float = 2.
461
+ qkv_bias: bool = True
462
+ qk_scale: Optional[float] = None
463
+ drop_rate: float = 0.
464
+ attn_drop_rate: float = 0.
465
+ drop_path_rate: float = 0.1
466
+ norm_layer: Callable = LayerNorm
467
+ ape: bool = False
468
+ patch_norm: bool = True
469
+ upscale: int = 2
470
+ img_range: float = 1.
471
+ num_feat: int = 64
472
+
473
+ def pad(self, x):
474
+ _, h, w, _ = x.shape
475
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
476
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
477
+ x = jnp.pad(x, ((0, 0), (0, mod_pad_h), (0, mod_pad_w), (0, 0)), 'reflect')
478
+ return x
479
+
480
+ @nn.compact
481
+ def __call__(self, x, training):
482
+ _, h_before, w_before, _ = x.shape
483
+ x = self.pad(x)
484
+ _, h, w, _ = x.shape
485
+ patches_resolution = [self.img_size // self.patch_size] * 2
486
+ num_patches = patches_resolution[0] * patches_resolution[1]
487
+
488
+ # conv_first
489
+ x = nn.Conv(self.embed_dim, (3, 3))(x)
490
+ res = x
491
+
492
+ # feature extraction
493
+ x_size = (h, w)
494
+ x = PatchEmbed(self.norm_layer if self.patch_norm else None)(x)
495
+
496
+ if self.ape:
497
+ absolute_pos_embed = \
498
+ self.param('ape', trunc_normal(std=.02), (1, num_patches, self.embed_dim))
499
+ x = x + absolute_pos_embed
500
+
501
+ x = nn.Dropout(self.drop_rate, deterministic=not training)(x)
502
+
503
+ dpr = [x.item() for x in np.linspace(0, self.drop_path_rate, sum(self.depths))]
504
+ for i_layer in range(len(self.depths)):
505
+ x = RSTB(
506
+ dim=self.embed_dim,
507
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
508
+ depth=self.depths[i_layer],
509
+ num_heads=self.num_heads[i_layer],
510
+ window_size=self.window_size,
511
+ mlp_ratio=self.mlp_ratio,
512
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
513
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
514
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
515
+ norm_layer=self.norm_layer,
516
+ downsample=None,
517
+ img_size=self.img_size,
518
+ patch_size=self.patch_size)(x, x_size, training)
519
+
520
+ x = self.norm_layer()(x) # B L C
521
+ x = PatchUnEmbed(self.embed_dim)(x, x_size)
522
+
523
+ # conv_after_body
524
+ x = nn.Conv(self.embed_dim, (3, 3))(x)
525
+ x = x + res
526
+
527
+ # conv_before_upsample
528
+ x = nn.activation.leaky_relu(nn.Conv(self.num_feat, (3, 3))(x))
529
+
530
+ # revert padding
531
+ x = x[:, :-(h - h_before) or None, :-(w - w_before) or None]
532
+ return x
model/tail.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flax.linen as nn
2
+
3
+ from .convnext import ConvNeXt
4
+ from .swin_ir import SwinIR
5
+
6
+
7
+ def build_tail(size: str):
8
+ """ Convenience function to build the three tails described in the paper. """
9
+ if size == 'air':
10
+ return lambda x, _: x
11
+ elif size == 'plus':
12
+ blocks = [(64, 3, True)] * 6 + [(96, 3, True)] * 7 + [(128, 3, True)] * 3
13
+ return ConvNeXt(blocks)
14
+ elif size == 'pro':
15
+ return SwinIR(depths=[7, 6], num_heads=[6, 6])
16
+ else:
17
+ raise NotImplementedError('size: ' + size)
18
+
model/thera.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import jax
4
+ from flax.core import unfreeze, freeze
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+ from jaxtyping import Array, ArrayLike, PyTree
8
+
9
+ from .edsr import EDSR
10
+ from .rdn import RDN
11
+ from .hyper import Hypernetwork
12
+ from .tail import build_tail
13
+ from .init import uniform_between, linear_up
14
+ from utils import make_grid, interpolate_grid, repeat_vmap
15
+
16
+
17
+ class Thermal(nn.Module):
18
+ w0_scale: float = 1.
19
+
20
+ @nn.compact
21
+ def __call__(self, x: ArrayLike, t, norm, k) -> Array:
22
+ phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
23
+ return jnp.sin(self.w0_scale * x + phase) * jnp.exp(-(self.w0_scale * norm)**2 * k * t)
24
+
25
+
26
+ class TheraField(nn.Module):
27
+ dim_hidden: int
28
+ dim_out: int
29
+ w0: float = 1.
30
+ c: float = 6.
31
+
32
+ @nn.compact
33
+ def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
34
+ # coordinate projection according to shared components ("first layer")
35
+ x = x @ components
36
+
37
+ # thermal activations
38
+ norm = jnp.linalg.norm(components, axis=-2)
39
+ x = Thermal(self.w0)(x, t, norm, k)
40
+
41
+ # linear projection from hidden to output space ("second layer")
42
+ w_std = math.sqrt(self.c / self.dim_hidden) / self.w0
43
+ dense_init_fn = uniform_between(-w_std, w_std)
44
+ x = nn.Dense(self.dim_out, kernel_init=dense_init_fn, use_bias=False)(x)
45
+
46
+ return x
47
+
48
+
49
+ class Thera:
50
+
51
+ def __init__(
52
+ self,
53
+ hidden_dim: int,
54
+ out_dim: int,
55
+ backbone: nn.Module,
56
+ tail: nn.Module,
57
+ k_init: float = None,
58
+ components_init_scale: float = None
59
+ ):
60
+ self.hidden_dim = hidden_dim
61
+ self.k_init = k_init
62
+ self.components_init_scale = components_init_scale
63
+
64
+ # single TheraField object whose `apply` method is used for all grid cells
65
+ self.field = TheraField(hidden_dim, out_dim)
66
+
67
+ # infer output size of the hypernetwork from a sample pass through the field;
68
+ # key doesnt matter as field params are only used for size inference
69
+ sample_params = self.field.init(jax.random.PRNGKey(0),
70
+ jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
71
+ sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
72
+ param_shapes = [p.shape for p in sample_params_flat]
73
+
74
+ self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)
75
+
76
+ def init(self, key, sample_source) -> PyTree:
77
+ keys = jax.random.split(key, 2)
78
+ sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
79
+ params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))
80
+
81
+ params['params']['k'] = jnp.array(self.k_init)
82
+ params['params']['components'] = \
83
+ linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))
84
+
85
+ return freeze(params)
86
+
87
+ def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
88
+ """
89
+ Performs a forward pass through the hypernetwork to obtain an encoding.
90
+ """
91
+ return self.hypernet.apply(
92
+ params, source, method=self.hypernet.get_encoding, **kwargs)
93
+
94
+ def apply_decoder(
95
+ self,
96
+ params: PyTree,
97
+ encoding: ArrayLike,
98
+ coords: ArrayLike,
99
+ t: ArrayLike,
100
+ return_jac: bool = False
101
+ ) -> Array | tuple[Array, Array]:
102
+ """
103
+ Performs a forward prediction through a grid of HxW Thera fields,
104
+ informed by `encoding`, at spatial and temporal coordinates
105
+ `coords` and `t`, respectively.
106
+ args:
107
+ params: Field parameters, shape (B, H, W, N)
108
+ encoding: Encoding tensor, shape (B, H, W, C)
109
+ coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
110
+ t: Temporal coordinates, shape (B, 1)
111
+ """
112
+ phi_params: PyTree = self.hypernet.apply(
113
+ params, encoding, coords, method=self.hypernet.get_params_at_coords)
114
+
115
+ # create local coordinate systems
116
+ source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
117
+ source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
118
+ interp_coords = interpolate_grid(coords, source_coords)
119
+ rel_coords = (coords - interp_coords)
120
+ rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
121
+ rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])
122
+
123
+ # three maps over params, coords; one over t; dont map k and components
124
+ in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
125
+ apply_field = repeat_vmap(self.field.apply, in_axes)
126
+ out = apply_field(phi_params, rel_coords, t, params['params']['k'],
127
+ params['params']['components'])
128
+
129
+ if return_jac:
130
+ apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
131
+ jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
132
+ params['params']['components'])
133
+ return out, jac
134
+
135
+ return out
136
+
137
+ def apply(
138
+ self,
139
+ params: ArrayLike,
140
+ source: ArrayLike,
141
+ coords: ArrayLike,
142
+ t: ArrayLike,
143
+ return_jac: bool = False,
144
+ **kwargs
145
+ ) -> Array:
146
+ """
147
+ Performs a forward pass through the Thera model.
148
+ """
149
+ encoding = self.apply_encoder(params, source, **kwargs)
150
+ out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
151
+ return out
152
+
153
+
154
+ def build_thera(
155
+ out_dim: int,
156
+ backbone: str,
157
+ size: str,
158
+ k_init: float = None,
159
+ components_init_scale: float = None
160
+ ):
161
+ """
162
+ Convenience function for building the three Thera sizes described in the paper.
163
+ """
164
+ hidden_dim = 32 if size == 'air' else 512
165
+
166
+ if backbone == 'edsr-baseline':
167
+ backbone_module = EDSR(None, num_blocks=16, num_feats=64)
168
+ elif backbone == 'rdn':
169
+ backbone_module = RDN()
170
+ else:
171
+ raise NotImplementedError(backbone)
172
+
173
+ tail_module = build_tail(size)
174
+
175
+ return Thera(hidden_dim, out_dim, backbone_module, tail_module, k_init, components_init_scale)
requirements.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2
+ # torch-cpu is sufficient since we only use it for data loading
3
+ --extra-index-url https://download.pytorch.org/whl/cpu
4
+
5
+ chex==0.1.7
6
+ ConfigArgParse==1.7
7
+ einops==0.6.1
8
+ flax==0.6.10
9
+ flaxmodels==0.1.3
10
+ jax==0.4.11
11
+ jaxlib==0.4.11+cuda11.cudnn86
12
+ jaxtyping==0.2.20
13
+ ml-dtypes==0.1.0
14
+ numpy==1.24.1
15
+ nvidia-cublas-cu11==11.11.3.6
16
+ nvidia-cuda-cupti-cu11==11.8.87
17
+ nvidia-cuda-nvcc-cu11==11.8.89
18
+ nvidia-cuda-runtime-cu11==11.8.89
19
+ nvidia-cudnn-cu11==8.9.2.26
20
+ nvidia-cufft-cu11==10.9.0.58
21
+ nvidia-cusolver-cu11==11.4.1.48
22
+ nvidia-cusparse-cu11==11.7.5.86
23
+ opt-einsum==3.3.0
24
+ optax==0.2.0
25
+ orbax-checkpoint==0.2.4
26
+ scipy==1.10.1
27
+ timm==0.9.6
28
+ torch==2.0.1+cpu
29
+ torchaudio==2.0.2+cpu
30
+ torchmetrics==1.2.0
31
+ torchvision==0.15.2+cpu
32
+ tqdm==4.65.0
33
+ transformers==4.46.3
34
+ Pillow==10.0.0
35
+ wandb
36
+
37
+ # gradio
38
+ gradio==4.44.1
39
+ gradio_imageslider==0.0.20
40
+ spaces
super_resolve.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from argparse import ArgumentParser, Namespace
4
+ import pickle
5
+
6
+ import jax
7
+ from jax import jit
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ from model import build_thera
13
+ from utils import make_grid, interpolate_grid
14
+
15
+ MEAN = jnp.array([.4488, .4371, .4040])
16
+ VAR = jnp.array([.25, .25, .25])
17
+ PATCH_SIZE = 256
18
+
19
+
20
+ def process_single(source, apply_encoder, apply_decoder, params, target_shape):
21
+ t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
22
+ coords_nearest = jnp.asarray(make_grid(target_shape)[None])
23
+ source_up = interpolate_grid(coords_nearest, source[None])
24
+ source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]
25
+
26
+ encoding = apply_encoder(params, source)
27
+ coords = jnp.asarray(make_grid(source_up.shape[1:3])[None]) # global sampling coords
28
+ out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)
29
+
30
+ for h_min in range(0, coords.shape[1], PATCH_SIZE):
31
+ h_max = min(h_min + PATCH_SIZE, coords.shape[1])
32
+ for w_min in range(0, coords.shape[2], PATCH_SIZE):
33
+ # apply decoder with one patch of coordinates
34
+ w_max = min(w_min + PATCH_SIZE, coords.shape[2])
35
+ coords_patch = coords[:, h_min:h_max, w_min:w_max]
36
+ out_patch = apply_decoder(params, encoding, coords_patch, t)
37
+ out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)
38
+
39
+ out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
40
+ out += source_up
41
+ return out
42
+
43
+
44
+ def process(source, model, params, target_shape, do_ensemble=True):
45
+ apply_encoder = jit(model.apply_encoder)
46
+ apply_decoder = jit(model.apply_decoder)
47
+
48
+ outs = []
49
+ for i_rot in range(4 if do_ensemble else 1):
50
+ source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
51
+ target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
52
+ out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_)
53
+ outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))
54
+
55
+ out = jnp.stack(outs).mean(0).clip(0., 1.)
56
+ return jnp.rint(out[0] * 255).astype(jnp.uint8)
57
+
58
+
59
+ def main(args: Namespace):
60
+ source = np.asarray(Image.open(args.in_file)) / 255.
61
+
62
+ if args.scale is not None:
63
+ if args.size is not None:
64
+ raise ValueError('Cannot specify both size and scale')
65
+ target_shape = (
66
+ round(source.shape[0] * args.scale),
67
+ round(source.shape[1] * args.scale),
68
+ )
69
+ elif args.size is not None:
70
+ target_shape = args.size
71
+ else:
72
+ raise ValueError('Must specify either size or scale')
73
+
74
+ with open(args.checkpoint, 'rb') as fh:
75
+ check = pickle.load(fh)
76
+ params, backbone, size = check['model'], check['backbone'], check['size']
77
+
78
+ model = build_thera(3, backbone, size)
79
+
80
+ out = process(source, model, params, target_shape, not args.no_ensemble)
81
+
82
+ Image.fromarray(np.asarray(out)).save(args.out_file)
83
+
84
+
85
+ def parse_args() -> Namespace:
86
+ parser = ArgumentParser()
87
+ parser.add_argument('in_file')
88
+ parser.add_argument('out_file')
89
+ parser.add_argument('--scale', type=float, help='Scale factor for super-resolution')
90
+ parser.add_argument('--size', type=int, nargs=2,
91
+ help='Target size (h, w), mutually exclusive with --scale')
92
+ parser.add_argument('--checkpoint', help='Path to checkpoint file')
93
+ parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
94
+ return parser.parse_args()
95
+
96
+
97
+ if __name__ == '__main__':
98
+ args = parse_args()
99
+ main(args)
utils.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import jax
4
+ import numpy as np
5
+
6
+
7
+ def repeat_vmap(fun, in_axes=[0]):
8
+ for axes in in_axes:
9
+ fun = jax.vmap(fun, in_axes=axes)
10
+ return fun
11
+
12
+
13
+ def make_grid(patch_size: int | tuple[int, int]):
14
+ if isinstance(patch_size, int):
15
+ patch_size = (patch_size, patch_size)
16
+ offset_h, offset_w = 1 / (2 * np.array(patch_size))
17
+ space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
18
+ space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
19
+ return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1) # [h, w]
20
+
21
+
22
+ def interpolate_grid(coords, grid, order=0):
23
+ """
24
+ args:
25
+ coords: Tensor of shape (B, H, W, 2) with coordinates in [-0.5, 0.5]
26
+ grid: Tensor of shape (B, H', W', C)
27
+ returns:
28
+ Tensor of shape (B, H, W, C) with interpolated values
29
+ """
30
+ # convert [-0.5, 0.5] -> [0, size], where pixel centers are expected at
31
+ # [-0.5 + 1 / (2*size), ..., 0.5 - 1 / (2*size)]
32
+ coords = coords.transpose((0, 3, 1, 2))
33
+ coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
34
+ coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
35
+ map_coordinates = partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest')
36
+ return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)