Spaces:
Running
on
L4
Running
on
L4
Alexander Becker
commited on
Commit
·
b139995
1
Parent(s):
4fd4a46
Add code
Browse files- .gitattributes +1 -0
- .gitignore +174 -0
- app.py +201 -4
- checkpoints/thera-edsr-plus.pkl +3 -0
- files/1_skyscrapers_2.png +3 -0
- files/2_manga1.png +3 -0
- files/3_bird.png +3 -0
- files/koala.png +3 -0
- files/manga3.png +3 -0
- files/raw/0853.png +3 -0
- files/raw/0853C.png +3 -0
- files/raw/69015.png +3 -0
- files/raw/GakuenNoise.png +3 -0
- files/raw/GakuenNoiseC.png +3 -0
- files/raw/UchiNoNyansDiary_000.png +3 -0
- files/raw/UchiNoNyansDiary_000_C.png +3 -0
- files/zebra_8.png +3 -0
- gradio_dualvision/__init__.py +26 -0
- gradio_dualvision/app_template.py +614 -0
- gradio_dualvision/gradio_patches/__init__.py +0 -0
- gradio_dualvision/gradio_patches/examples.py +36 -0
- gradio_dualvision/gradio_patches/gallery.py +77 -0
- gradio_dualvision/gradio_patches/gallery.pyi +82 -0
- gradio_dualvision/gradio_patches/imagesliderplus.py +156 -0
- gradio_dualvision/gradio_patches/imagesliderplus.pyi +161 -0
- gradio_dualvision/gradio_patches/radio.py +62 -0
- gradio_dualvision/gradio_patches/radio.pyi +67 -0
- gradio_dualvision/gradio_patches/templates/component/__vite-browser-external-2447137e.js +4 -0
- gradio_dualvision/gradio_patches/templates/component/index.js +0 -0
- gradio_dualvision/gradio_patches/templates/component/style.css +1 -0
- gradio_dualvision/gradio_patches/templates/component/wrapper-6f348d45-19fa94bf.js +2453 -0
- gradio_dualvision/gradio_patches/templates/example/index.js +95 -0
- gradio_dualvision/gradio_patches/templates/example/style.css +1 -0
- gradio_dualvision/version.py +25 -0
- model/__init__.py +2 -0
- model/convnext.py +55 -0
- model/edsr.py +122 -0
- model/hyper.py +41 -0
- model/init.py +24 -0
- model/rdn.py +72 -0
- model/swin_ir.py +532 -0
- model/tail.py +18 -0
- model/thera.py +175 -0
- requirements.txt +40 -0
- super_resolve.py +99 -0
- utils.py +36 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# Ruff stuff:
|
171 |
+
.ruff_cache/
|
172 |
+
|
173 |
+
# PyPI configuration file
|
174 |
+
.pypirc
|
app.py
CHANGED
@@ -1,8 +1,205 @@
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
-
demo.launch()
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
import gradio as gr
|
6 |
+
from gradio_dualvision import DualVisionApp
|
7 |
+
from gradio_dualvision.gradio_patches.radio import Radio
|
8 |
+
from PIL import Image
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from model import build_thera
|
12 |
+
from super_resolve import process
|
13 |
+
|
14 |
+
CHECKPOINT = "checkpoints/thera-edsr-plus.pkl"
|
15 |
+
|
16 |
+
|
17 |
+
class TheraApp(DualVisionApp):
|
18 |
+
DEFAULT_SCALE = 3.1415
|
19 |
+
DEFAULT_DO_ENSEMBLE = False
|
20 |
+
|
21 |
+
def make_header(self):
|
22 |
+
gr.Markdown(
|
23 |
+
"""
|
24 |
+
## Thera: Aliasing-Free Arbitrary-Scale Super-Resolution with Neural Heat Fields
|
25 |
+
<p align="center">
|
26 |
+
<a title="Website" href="https://therasr.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
27 |
+
<img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
|
28 |
+
</a>
|
29 |
+
<a title="arXiv" href="https://arxiv.org/pdf/2311.17643" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
30 |
+
<img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-AF3436">
|
31 |
+
</a>
|
32 |
+
<a title="Github" href="https://github.com/prs-eth/thera" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
33 |
+
<img src="https://img.shields.io/github/stars/prs-eth/thera?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
|
34 |
+
</a>
|
35 |
+
</p>
|
36 |
+
<p align="center" style="margin-top: 0px;">
|
37 |
+
Upload a photo or select an example below to do arbitrary-scale super-resolution in real time!
|
38 |
+
</p>
|
39 |
+
"""
|
40 |
+
)
|
41 |
+
|
42 |
+
def build_user_components(self):
|
43 |
+
with gr.Column():
|
44 |
+
scale = gr.Slider(
|
45 |
+
label="Scaling factor",
|
46 |
+
minimum=1,
|
47 |
+
maximum=6,
|
48 |
+
step=0.01,
|
49 |
+
value=self.DEFAULT_SCALE,
|
50 |
+
)
|
51 |
+
do_ensemble = gr.Radio(
|
52 |
+
[
|
53 |
+
("No", False),
|
54 |
+
("Yes", True),
|
55 |
+
],
|
56 |
+
label="Do Ensemble",
|
57 |
+
value=self.DEFAULT_DO_ENSEMBLE,
|
58 |
+
)
|
59 |
+
return {
|
60 |
+
"scale": scale,
|
61 |
+
"do_ensemble": do_ensemble,
|
62 |
+
}
|
63 |
+
|
64 |
+
def process(self, image_in: Image.Image, **kwargs):
|
65 |
+
scale = kwargs.get("scale", self.DEFAULT_SCALE)
|
66 |
+
do_ensemble = kwargs.get("do_ensemble", self.DEFAULT_DO_ENSEMBLE)
|
67 |
+
|
68 |
+
source = np.asarray(image_in) / 255.
|
69 |
+
|
70 |
+
# determine target shape
|
71 |
+
target_shape = (
|
72 |
+
round(source.shape[0] * scale),
|
73 |
+
round(source.shape[1] * scale),
|
74 |
+
)
|
75 |
+
|
76 |
+
# load model
|
77 |
+
with open(CHECKPOINT, 'rb') as fh:
|
78 |
+
check = pickle.load(fh)
|
79 |
+
params, backbone, size = check['model'], check['backbone'], check['size']
|
80 |
+
|
81 |
+
model = build_thera(3, backbone, size)
|
82 |
+
|
83 |
+
out = process(source, model, params, target_shape, do_ensemble=do_ensemble)
|
84 |
+
out = Image.fromarray(np.asarray(out))
|
85 |
+
|
86 |
+
nearest = image_in.resize(out.size, Image.NEAREST)
|
87 |
+
|
88 |
+
out_modalities = {
|
89 |
+
"nearest": nearest,
|
90 |
+
"out": out,
|
91 |
+
}
|
92 |
+
out_settings = {
|
93 |
+
'scale': scale,
|
94 |
+
'do_ensemble': do_ensemble,
|
95 |
+
}
|
96 |
+
return out_modalities, out_settings
|
97 |
+
|
98 |
+
def process_components(
|
99 |
+
self, image_in, modality_selector_left, modality_selector_right, **kwargs
|
100 |
+
):
|
101 |
+
if image_in is None:
|
102 |
+
raise gr.Error("Input image is required")
|
103 |
+
|
104 |
+
image_settings = {}
|
105 |
+
if isinstance(image_in, str):
|
106 |
+
image_settings_path = image_in + ".settings.json"
|
107 |
+
if os.path.isfile(image_settings_path):
|
108 |
+
with open(image_settings_path, "r") as f:
|
109 |
+
image_settings = json.load(f)
|
110 |
+
image_in = Image.open(image_in).convert("RGB")
|
111 |
+
else:
|
112 |
+
if not isinstance(image_in, Image.Image):
|
113 |
+
raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
|
114 |
+
image_in = image_in.convert("RGB")
|
115 |
+
image_settings.update(kwargs)
|
116 |
+
|
117 |
+
results_dict, results_settings = self.process(image_in, **image_settings)
|
118 |
+
|
119 |
+
if not isinstance(results_dict, dict):
|
120 |
+
raise gr.Error(
|
121 |
+
f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}"
|
122 |
+
)
|
123 |
+
if len(results_dict) == 0:
|
124 |
+
raise gr.Error("`process` did not return any modalities")
|
125 |
+
for k, v in results_dict.items():
|
126 |
+
if not isinstance(k, str):
|
127 |
+
raise gr.Error(
|
128 |
+
f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}"
|
129 |
+
)
|
130 |
+
if k == self.key_original_image:
|
131 |
+
raise gr.Error(
|
132 |
+
f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
|
133 |
+
)
|
134 |
+
if not isinstance(v, Image.Image):
|
135 |
+
raise gr.Error(
|
136 |
+
f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
|
137 |
+
)
|
138 |
+
if len(results_settings) != len(self.input_keys):
|
139 |
+
raise gr.Error(
|
140 |
+
f"Expected number of settings ({len(self.input_keys)}), returned ({len(results_settings)})"
|
141 |
+
)
|
142 |
+
if any(k not in results_settings for k in self.input_keys):
|
143 |
+
raise gr.Error(f"Mismatching setgings keys")
|
144 |
+
results_settings = {
|
145 |
+
k: cls(**ctor_args, value=results_settings[k])
|
146 |
+
for k, cls, ctor_args in zip(
|
147 |
+
self.input_keys, self.input_cls, self.input_kwargs
|
148 |
+
)
|
149 |
+
}
|
150 |
+
|
151 |
+
results_dict = {
|
152 |
+
**results_dict,
|
153 |
+
self.key_original_image: image_in,
|
154 |
+
}
|
155 |
+
|
156 |
+
results_state = [[v, k] for k, v in results_dict.items()]
|
157 |
+
modalities = list(results_dict.keys())
|
158 |
+
|
159 |
+
modality_left = (
|
160 |
+
modality_selector_left
|
161 |
+
if modality_selector_left in modalities
|
162 |
+
else modalities[0]
|
163 |
+
)
|
164 |
+
modality_right = (
|
165 |
+
modality_selector_right
|
166 |
+
if modality_selector_right in modalities
|
167 |
+
else modalities[1]
|
168 |
+
)
|
169 |
|
170 |
+
return [
|
171 |
+
results_state, # goes to a gr.Gallery
|
172 |
+
[
|
173 |
+
results_dict[modality_left],
|
174 |
+
results_dict[modality_right],
|
175 |
+
], # ImageSliderPlus
|
176 |
+
Radio(
|
177 |
+
choices=modalities,
|
178 |
+
value=modality_left,
|
179 |
+
label="Left",
|
180 |
+
key="Left",
|
181 |
+
),
|
182 |
+
Radio(
|
183 |
+
choices=modalities if self.left_selector_visible else modalities[1:],
|
184 |
+
value=modality_right,
|
185 |
+
label="Right",
|
186 |
+
key="Right",
|
187 |
+
),
|
188 |
+
*results_settings.values(),
|
189 |
+
]
|
190 |
|
|
|
|
|
191 |
|
192 |
+
with TheraApp(
|
193 |
+
title="Thera Arbitrary-Scale Super-Resolution",
|
194 |
+
examples_path="files",
|
195 |
+
examples_per_page=12,
|
196 |
+
squeeze_canvas=True,
|
197 |
+
advanced_settings_can_be_half_width=False,
|
198 |
+
#spaces_zero_gpu_enabled=True,
|
199 |
+
) as demo:
|
200 |
+
demo.queue(
|
201 |
+
api_open=False,
|
202 |
+
).launch(
|
203 |
+
server_name="0.0.0.0",
|
204 |
+
server_port=7860,
|
205 |
+
)
|
checkpoints/thera-edsr-plus.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a805ca6f0486d9eba8f228200340a0e6aedde16529e11fc7b98dc26d830d9aa8
|
3 |
+
size 31632862
|
files/1_skyscrapers_2.png
ADDED
![]() |
Git LFS Details
|
files/2_manga1.png
ADDED
![]() |
Git LFS Details
|
files/3_bird.png
ADDED
![]() |
Git LFS Details
|
files/koala.png
ADDED
![]() |
Git LFS Details
|
files/manga3.png
ADDED
![]() |
Git LFS Details
|
files/raw/0853.png
ADDED
![]() |
Git LFS Details
|
files/raw/0853C.png
ADDED
![]() |
Git LFS Details
|
files/raw/69015.png
ADDED
![]() |
Git LFS Details
|
files/raw/GakuenNoise.png
ADDED
![]() |
Git LFS Details
|
files/raw/GakuenNoiseC.png
ADDED
![]() |
Git LFS Details
|
files/raw/UchiNoNyansDiary_000.png
ADDED
![]() |
Git LFS Details
|
files/raw/UchiNoNyansDiary_000_C.png
ADDED
![]() |
Git LFS Details
|
files/zebra_8.png
ADDED
![]() |
Git LFS Details
|
gradio_dualvision/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
2 |
+
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
|
3 |
+
# See https://creativecommons.org/licenses/by-sa/4.0/ for details.
|
4 |
+
# --------------------------------------------------------------------------
|
5 |
+
# DualVision is a Gradio template app for image processing. It was developed
|
6 |
+
# to support the Marigold project. If you find this code useful, we kindly
|
7 |
+
# ask you to cite our most relevant papers.
|
8 |
+
# More information about Marigold:
|
9 |
+
# https://marigoldmonodepth.github.io
|
10 |
+
# https://marigoldcomputervision.github.io
|
11 |
+
# Efficient inference pipelines are now part of diffusers:
|
12 |
+
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
|
13 |
+
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
|
14 |
+
# Examples of trained models and live demos:
|
15 |
+
# https://huggingface.co/prs-eth
|
16 |
+
# Related projects:
|
17 |
+
# https://marigolddepthcompletion.github.io/
|
18 |
+
# https://rollingdepth.github.io/
|
19 |
+
# Citation (BibTeX):
|
20 |
+
# https://github.com/prs-eth/Marigold#-citation
|
21 |
+
# https://github.com/prs-eth/Marigold-DC#-citation
|
22 |
+
# https://github.com/prs-eth/rollingdepth#-citation
|
23 |
+
# --------------------------------------------------------------------------
|
24 |
+
|
25 |
+
from .version import __version__
|
26 |
+
from .app_template import DualVisionApp
|
gradio_dualvision/app_template.py
ADDED
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
2 |
+
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
|
3 |
+
# See https://creativecommons.org/licenses/by-sa/4.0/ for details.
|
4 |
+
# --------------------------------------------------------------------------
|
5 |
+
# DualVision is a Gradio template app for image processing. It was developed
|
6 |
+
# to support the Marigold project. If you find this code useful, we kindly
|
7 |
+
# ask you to cite our most relevant papers.
|
8 |
+
# More information about Marigold:
|
9 |
+
# https://marigoldmonodepth.github.io
|
10 |
+
# https://marigoldcomputervision.github.io
|
11 |
+
# Efficient inference pipelines are now part of diffusers:
|
12 |
+
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
|
13 |
+
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
|
14 |
+
# Examples of trained models and live demos:
|
15 |
+
# https://huggingface.co/prs-eth
|
16 |
+
# Related projects:
|
17 |
+
# https://marigolddepthcompletion.github.io/
|
18 |
+
# https://rollingdepth.github.io/
|
19 |
+
# Citation (BibTeX):
|
20 |
+
# https://github.com/prs-eth/Marigold#-citation
|
21 |
+
# https://github.com/prs-eth/Marigold-DC#-citation
|
22 |
+
# https://github.com/prs-eth/rollingdepth#-citation
|
23 |
+
# --------------------------------------------------------------------------
|
24 |
+
import glob
|
25 |
+
import json
|
26 |
+
import os
|
27 |
+
import re
|
28 |
+
|
29 |
+
import gradio as gr
|
30 |
+
import spaces
|
31 |
+
from PIL import Image
|
32 |
+
from gradio.components.base import Component
|
33 |
+
|
34 |
+
from .gradio_patches.examples import Examples
|
35 |
+
from .gradio_patches.gallery import Gallery
|
36 |
+
from .gradio_patches.imagesliderplus import ImageSliderPlus
|
37 |
+
from .gradio_patches.radio import Radio
|
38 |
+
|
39 |
+
|
40 |
+
class DualVisionApp(gr.Blocks):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
title,
|
44 |
+
examples_path="examples",
|
45 |
+
examples_per_page=12,
|
46 |
+
examples_cache="lazy",
|
47 |
+
squeeze_canvas=True,
|
48 |
+
squeeze_viewport_height_pct=75,
|
49 |
+
left_selector_visible=False,
|
50 |
+
advanced_settings_can_be_half_width=True,
|
51 |
+
key_original_image="Original",
|
52 |
+
spaces_zero_gpu_enabled=False,
|
53 |
+
spaces_zero_gpu_duration=None,
|
54 |
+
slider_position=0.5,
|
55 |
+
slider_line_color="#FFF",
|
56 |
+
slider_line_width="4px",
|
57 |
+
slider_arrows_color="#FFF",
|
58 |
+
slider_arrows_width="2px",
|
59 |
+
gallery_thumb_min_size="96px",
|
60 |
+
**kwargs,
|
61 |
+
):
|
62 |
+
"""
|
63 |
+
A wrapper around Gradio Blocks class that implements an image processing demo template. All the user has to do
|
64 |
+
is to subclass this class and implement two methods: `process` implementing the image processing, and
|
65 |
+
`build_user_components` implementing Gradio components reading the additional processing arguments.
|
66 |
+
Args:
|
67 |
+
title: Title of the application (str, required).
|
68 |
+
examples_path: Base path where examples will be searched (Default: `"examples"`).
|
69 |
+
examples_per_page: How many examples to show at the bottom of the app (Default: `12`).
|
70 |
+
examples_cache: Examples caching policy, corresponding to `cache_examples` argument of gradio.Examples (Default: `"lazy"`).
|
71 |
+
squeeze_canvas: When True, the image is fit to the browser viewport. When False, the image is fit to width (Default: `True`).
|
72 |
+
squeeze_viewport_height_pct: Percentage of the browser viewport height (Default: `75`).
|
73 |
+
left_selector_visible: Whether controls for changing modalities in the left part of the slider are visible (Default: `False`).
|
74 |
+
key_original_image: Name of the key under which the input image is shown in the modality selectors (Default: `"Original"`).
|
75 |
+
advanced_settings_can_be_half_width: Whether allow placing advanced settings dropdown in half-column space whenever possible (Default: `True`).
|
76 |
+
spaces_zero_gpu_enabled: When True, the app wraps the processing function with the ZeroGPU decorator.
|
77 |
+
spaces_zero_gpu_duration: Defines an integer duration in seconds passed into the ZeroGPU decorator.
|
78 |
+
slider_position: Position of the slider between 0 and 1 (Default: `0.5`).
|
79 |
+
slider_line_color: Color of the slider line (Default: `"#FFF"`).
|
80 |
+
slider_line_width: Width of the slider line (Default: `"4px"`).
|
81 |
+
slider_arrows_color: Color of the slider arrows (Default: `"#FFF"`).
|
82 |
+
slider_arrows_width: Width of the slider arrows (Default: `2px`).
|
83 |
+
gallery_thumb_min_size: Min size of the gallery thumbnail (Default: `96px`).
|
84 |
+
**kwargs: Any other arguments that Gradio Blocks class can take.
|
85 |
+
"""
|
86 |
+
squeeze_viewport_height_pct = int(squeeze_viewport_height_pct)
|
87 |
+
if not 50 <= squeeze_viewport_height_pct <= 100:
|
88 |
+
raise gr.Error(
|
89 |
+
"`squeeze_viewport_height_pct` should be an integer between 50 and 100."
|
90 |
+
)
|
91 |
+
if not os.path.isdir(examples_path):
|
92 |
+
raise gr.Error("`examples_path` should be a directory.")
|
93 |
+
if not 0 <= slider_position <= 1:
|
94 |
+
raise gr.Error("`slider_position` should be between 0 and 1.")
|
95 |
+
kwargs = {k: v for k, v in kwargs.items()}
|
96 |
+
kwargs["title"] = title
|
97 |
+
self.examples_path = examples_path
|
98 |
+
self.examples_per_page = examples_per_page
|
99 |
+
self.examples_cache = examples_cache
|
100 |
+
self.key_original_image = key_original_image
|
101 |
+
self.slider_position = slider_position
|
102 |
+
self.input_keys = None
|
103 |
+
self.input_cls = None
|
104 |
+
self.input_kwargs = None
|
105 |
+
self.left_selector_visible = left_selector_visible
|
106 |
+
self.advanced_settings_can_be_half_width = advanced_settings_can_be_half_width
|
107 |
+
if spaces_zero_gpu_enabled:
|
108 |
+
self.process_components = spaces.GPU(
|
109 |
+
self.process_components, duration=spaces_zero_gpu_duration
|
110 |
+
)
|
111 |
+
self.head = ""
|
112 |
+
self.head += """
|
113 |
+
<script>
|
114 |
+
let observerFooterButtons = new MutationObserver((mutationsList, observer) => {
|
115 |
+
const oldButtonLeft = document.querySelector(".show-api");
|
116 |
+
const oldButtonRight = document.querySelector(".built-with");
|
117 |
+
if (!oldButtonRight || !oldButtonLeft) {
|
118 |
+
return;
|
119 |
+
}
|
120 |
+
observer.disconnect();
|
121 |
+
|
122 |
+
const parentDiv = oldButtonLeft.parentNode;
|
123 |
+
if (!parentDiv) return;
|
124 |
+
|
125 |
+
const createButton = (referenceButton, text, href) => {
|
126 |
+
let newButton = referenceButton.cloneNode(true);
|
127 |
+
newButton.href = href;
|
128 |
+
newButton.textContent = text;
|
129 |
+
newButton.className = referenceButton.className;
|
130 |
+
newButton.style.textDecoration = "none";
|
131 |
+
newButton.style.display = "inline-block";
|
132 |
+
newButton.style.cursor = "pointer";
|
133 |
+
return newButton;
|
134 |
+
};
|
135 |
+
|
136 |
+
const newButton0 = createButton(oldButtonRight, "Built with Gradio DualVision", "https://github.com/toshas/gradio-dualvision");
|
137 |
+
const newButton1 = createButton(oldButtonRight, "Template by Anton Obukhov", "https://www.obukhov.ai");
|
138 |
+
const newButton2 = createButton(oldButtonRight, "Licensed under CC BY-SA 4.0", "http://creativecommons.org/licenses/by-sa/4.0/");
|
139 |
+
|
140 |
+
const separatorDiv = document.createElement("div");
|
141 |
+
separatorDiv.className = "svelte-1rjryqp";
|
142 |
+
separatorDiv.textContent = "·";
|
143 |
+
|
144 |
+
parentDiv.replaceChild(newButton0, oldButtonLeft);
|
145 |
+
parentDiv.replaceChild(newButton1, oldButtonRight);
|
146 |
+
parentDiv.appendChild(separatorDiv);
|
147 |
+
parentDiv.appendChild(newButton2);
|
148 |
+
});
|
149 |
+
observerFooterButtons.observe(document.body, { childList: true, subtree: true });
|
150 |
+
</script>
|
151 |
+
"""
|
152 |
+
if kwargs.get("analytics_enabled") is not False:
|
153 |
+
self.head += f"""
|
154 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
|
155 |
+
<script>
|
156 |
+
window.dataLayer = window.dataLayer || [];
|
157 |
+
function gtag() {{dataLayer.push(arguments);}}
|
158 |
+
gtag('js', new Date());
|
159 |
+
gtag('config', 'G-1FWSVCGZTG');
|
160 |
+
</script>
|
161 |
+
"""
|
162 |
+
self.css = f"""
|
163 |
+
body {{ /* tighten the layout */
|
164 |
+
flex-grow: 0 !important;
|
165 |
+
}}
|
166 |
+
.sliderrow {{ /* center the slider */
|
167 |
+
display: flex;
|
168 |
+
justify-content: center;
|
169 |
+
}}
|
170 |
+
.slider {{ /* center the slider */
|
171 |
+
display: flex;
|
172 |
+
justify-content: center;
|
173 |
+
width: 100%;
|
174 |
+
}}
|
175 |
+
.slider .disabled {{ /* hide the main slider before image load */
|
176 |
+
visibility: hidden;
|
177 |
+
}}
|
178 |
+
.slider .svelte-9gxdi0 {{ /* hide the component label in the top-left corner before image load */
|
179 |
+
visibility: hidden;
|
180 |
+
}}
|
181 |
+
.slider .svelte-kzcjhc .icon-wrap {{
|
182 |
+
height: 0px; /* remove unnecessary spaces in captions before image load */
|
183 |
+
}}
|
184 |
+
.slider .svelte-kzcjhc.wrap {{
|
185 |
+
padding-top: 0px; /* remove unnecessary spaces in captions before image load */
|
186 |
+
}}
|
187 |
+
.slider .svelte-3w3rth {{ /* hide the dummy icon from the right pane before image load */
|
188 |
+
visibility: hidden;
|
189 |
+
}}
|
190 |
+
.svelte-106mu0e.icon-buttons {{ /* download button */
|
191 |
+
/*right: unset;*/
|
192 |
+
/*left: 8px;*/
|
193 |
+
}}
|
194 |
+
.slider .fixed {{ /* fix the opacity of the right pane image */
|
195 |
+
background-color: var(--anim-block-background-fill);
|
196 |
+
}}
|
197 |
+
.slider .inner {{ /* style slider line */
|
198 |
+
width: {slider_line_width};
|
199 |
+
background: {slider_line_color};
|
200 |
+
}}
|
201 |
+
.slider .icon-wrap svg {{ /* style slider arrows */
|
202 |
+
stroke: {slider_arrows_color};
|
203 |
+
stroke-width: {slider_arrows_width};
|
204 |
+
}}
|
205 |
+
.slider .icon-wrap path {{ /* style slider arrows */
|
206 |
+
fill: {slider_arrows_color};
|
207 |
+
}}
|
208 |
+
.row_reverse {{
|
209 |
+
flex-direction: row-reverse;
|
210 |
+
}}
|
211 |
+
.gallery.svelte-l4wpk0 {{ /* make examples gallery tiles square */
|
212 |
+
width: max({gallery_thumb_min_size}, calc(100vw / 8));
|
213 |
+
height: max({gallery_thumb_min_size}, calc(100vw / 8));
|
214 |
+
}}
|
215 |
+
.gallery.svelte-l4wpk0 img {{ /* make examples gallery tiles square */
|
216 |
+
width: max({gallery_thumb_min_size}, calc(100vw / 8));
|
217 |
+
height: max({gallery_thumb_min_size}, calc(100vw / 8));
|
218 |
+
}}
|
219 |
+
.gallery.svelte-l4wpk0 img {{ /* remove slider line from previews */
|
220 |
+
clip-path: inset(0 0 0 0);
|
221 |
+
}}
|
222 |
+
.gallery.svelte-l4wpk0 span {{ /* remove slider line from previews */
|
223 |
+
visibility: hidden;
|
224 |
+
}}
|
225 |
+
h1, h2, h3 {{ /* center markdown headings */
|
226 |
+
text-align: center;
|
227 |
+
display: block;
|
228 |
+
}}
|
229 |
+
#settings-accordion {{
|
230 |
+
margin: 0 auto;
|
231 |
+
max-width: 500px;
|
232 |
+
}}
|
233 |
+
"""
|
234 |
+
if squeeze_canvas:
|
235 |
+
self.head += f"""
|
236 |
+
<script>
|
237 |
+
// fixes vertical size of the component when used inside of iframeResizer (on spaces)
|
238 |
+
function squeezeViewport() {{
|
239 |
+
if (typeof window.parentIFrame === "undefined") return;
|
240 |
+
const images = document.querySelectorAll('.slider img');
|
241 |
+
window.parentIFrame.getPageInfo((info) => {{
|
242 |
+
images.forEach((img) => {{
|
243 |
+
const imgMaxHeightNew = (info.clientHeight * {squeeze_viewport_height_pct}) / 100;
|
244 |
+
img.style.maxHeight = `${{imgMaxHeightNew}}px`;
|
245 |
+
// window.parentIFrame.size(0, null); // tighten the layout; body's flex-grow: 0 is less intrusive
|
246 |
+
}});
|
247 |
+
}});
|
248 |
+
}}
|
249 |
+
window.addEventListener('resize', squeezeViewport);
|
250 |
+
|
251 |
+
// fixes gradio-imageslider wrong position behavior when using fitting to content by triggering resize
|
252 |
+
let observer = new MutationObserver((mutationsList) => {{
|
253 |
+
const images = document.querySelectorAll('.slider img');
|
254 |
+
images.forEach((img) => {{
|
255 |
+
if (img.complete) {{
|
256 |
+
window.dispatchEvent(new Event('resize'));
|
257 |
+
}} else {{
|
258 |
+
img.onload = () => {{
|
259 |
+
window.dispatchEvent(new Event('resize'));
|
260 |
+
}}
|
261 |
+
}}
|
262 |
+
}});
|
263 |
+
}});
|
264 |
+
observer.observe(document.body, {{ childList: true, subtree: true }});
|
265 |
+
</script>
|
266 |
+
"""
|
267 |
+
self.css += f"""
|
268 |
+
.slider {{ /* make the slider dimensions fit to the uploaded content dimensions */
|
269 |
+
max-width: fit-content;
|
270 |
+
}}
|
271 |
+
.slider .half-wrap {{ /* make the empty component width almost full before image load */
|
272 |
+
width: 70vw;
|
273 |
+
}}
|
274 |
+
.slider img {{ /* Ensures image fits inside the viewport */
|
275 |
+
max-height: {squeeze_viewport_height_pct}vh;
|
276 |
+
}}
|
277 |
+
"""
|
278 |
+
else:
|
279 |
+
self.css += f"""
|
280 |
+
.slider .half-wrap {{ /* make the upload area full width */
|
281 |
+
width: 100%;
|
282 |
+
}}
|
283 |
+
"""
|
284 |
+
kwargs["css"] = kwargs.get("css", "") + self.css
|
285 |
+
kwargs["head"] = kwargs.get("head", "") + self.head
|
286 |
+
super().__init__(**kwargs)
|
287 |
+
with self:
|
288 |
+
self.make_interface()
|
289 |
+
|
290 |
+
def process(self, image_in: Image.Image, **kwargs):
    """Turn one input image into a dict of output modalities.

    Subclasses must override this. The override receives the input image and
    the current settings, and returns two dictionaries: the computed
    modalities (name -> PIL image) and the settings that were actually used.
    """
    raise NotImplementedError("Please override the `process` method.")
|
297 |
+
|
298 |
+
def build_user_components(self):
    """Build the controls shown inside the Advanced Settings accordion.

    Subclasses must override this and return a flat dict mapping setting
    names to Gradio components; that dict is later fed into `process`.
    Layout may be arranged with gr.Row(), gr.Column(), and similar context
    managers.
    """
    raise NotImplementedError("Please override the `build_user_components` method.")
|
305 |
+
|
306 |
+
def discover_examples(self):
    """Return the sorted list of example image paths found in `examples_path`.

    Only files ending in .jpg/.jpeg/.png (any case of those spellings) match.
    """
    candidates = glob.glob(f"{self.examples_path}/*")
    is_image = re.compile(r".*\.(jpg|JPG|jpeg|JPEG|png|PNG)$").match
    return sorted(path for path in candidates if is_image(path))
|
314 |
+
|
315 |
+
def process_components(
    self, image_in, modality_selector_left, modality_selector_right, **kwargs
):
    """
    Wraps the call to `process`. Returns results in a structure used by the
    gallery, slider, and radio components.

    Args:
        image_in: Input image, either a file path (str) or a PIL image.
        modality_selector_left: Currently selected left-pane modality name.
        modality_selector_right: Currently selected right-pane modality name.
        **kwargs: Per-request settings, keyed by `self.input_keys`; these
            override any sidecar settings loaded from disk.

    Returns:
        A list of component updates: gallery state (list of [image, name]
        pairs), the slider image pair, the two Radio selectors, and the
        refreshed settings components.

    Raises:
        gr.Error: If the input is missing or `process` returned a malformed
            result.
    """
    if image_in is None:
        raise gr.Error("Input image is required")

    # Example images may carry a sidecar "<path>.settings.json"; load it as
    # defaults and let explicit kwargs take precedence.
    image_settings = {}
    if isinstance(image_in, str):
        image_settings_path = image_in + ".settings.json"
        if os.path.isfile(image_settings_path):
            with open(image_settings_path, "r") as f:
                image_settings = json.load(f)
        image_in = Image.open(image_in).convert("RGB")
    else:
        if not isinstance(image_in, Image.Image):
            raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
        image_in = image_in.convert("RGB")
    image_settings.update(kwargs)

    results_dict, results_settings = self.process(image_in, **image_settings)

    # Validate the subclass-provided `process` output before touching the UI.
    if not isinstance(results_dict, dict):
        raise gr.Error(
            f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}"
        )
    if len(results_dict) == 0:
        raise gr.Error("`process` did not return any modalities")
    for k, v in results_dict.items():
        if not isinstance(k, str):
            raise gr.Error(
                f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}"
            )
        if k == self.key_original_image:
            raise gr.Error(
                f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
            )
        if not isinstance(v, Image.Image):
            raise gr.Error(
                f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
            )
    if len(results_settings) != len(self.input_keys):
        raise gr.Error(
            f"Expected number of settings ({len(self.input_keys)}), returned ({len(results_settings)})"
        )
    if any(k not in results_settings for k in self.input_keys):
        # BUG FIX: the message previously read "Mismatching setgings keys"
        # (typo) and was a placeholder-less f-string.
        raise gr.Error("Mismatching settings keys")

    # Recreate the settings components so the UI reflects the values that
    # were actually applied by `process`.
    results_settings = {
        k: cls(**ctor_args, value=results_settings[k])
        for k, cls, ctor_args in zip(
            self.input_keys, self.input_cls, self.input_kwargs
        )
    }

    # The input image is always exposed first under the reserved key.
    results_dict = {
        self.key_original_image: image_in,
        **results_dict,
    }

    results_state = [[v, k] for k, v in results_dict.items()]
    modalities = list(results_dict.keys())

    # Keep the current selections if still valid; otherwise fall back to the
    # input image (left) and the first computed modality (right).
    modality_left = (
        modality_selector_left
        if modality_selector_left in modalities
        else modalities[0]
    )
    modality_right = (
        modality_selector_right
        if modality_selector_right in modalities
        else modalities[1]
    )

    return [
        results_state,  # goes to a gr.Gallery
        [
            results_dict[modality_left],
            results_dict[modality_right],
        ],  # ImageSliderPlus
        Radio(
            choices=modalities,
            value=modality_left,
            label="Left",
            key="Left",
        ),
        Radio(
            choices=modalities if self.left_selector_visible else modalities[1:],
            value=modality_right,
            label="Right",
            key="Right",
        ),
        *results_settings.values(),
    ]
|
410 |
+
|
411 |
+
def on_process_first(
    self,
    image_slider,
    modality_selector_left=None,
    modality_selector_right=None,
    *args,
):
    """Handle the first processing trigger (image upload or example click).

    The slider payload carries the uploaded image in its first slot; any
    trailing positional args are the settings values, in `input_keys` order.
    """
    image_in = image_slider[0]
    settings = dict(zip(self.input_keys, args)) if args else {}
    return self.process_components(
        image_in, modality_selector_left, modality_selector_right, **settings
    )
|
425 |
+
|
426 |
+
def on_process_subsequent(
    self, results_state, modality_selector_left, modality_selector_right, *args
):
    """Re-run processing on the cached input image with the current settings.

    The cached input is recovered from the gallery state under the reserved
    `key_original_image` key.
    """
    if results_state is None:
        raise gr.Error("Upload an image first or use an example below.")
    by_name = {name: img for img, name in results_state}
    image_in = by_name[self.key_original_image]
    settings = dict(zip(self.input_keys, args))
    return self.process_components(
        image_in, modality_selector_left, modality_selector_right, **settings
    )
|
437 |
+
|
438 |
+
def on_selector_change_left(
    self, results_state, image_slider, modality_selector_left
):
    """Swap the left slider pane to the newly selected modality; keep right."""
    by_name = {name: img for img, name in results_state}
    return [by_name[modality_selector_left], image_slider[1]]
|
443 |
+
|
444 |
+
def on_selector_change_right(
    self, results_state, image_slider, modality_selector_right
):
    """Swap the right slider pane to the newly selected modality; keep left."""
    by_name = {name: img for img, name in results_state}
    return [image_slider[0], by_name[modality_selector_right]]
|
449 |
+
|
450 |
+
def make_interface(self):
    """
    Constructs the entire Gradio Blocks interface.

    Layout: header, hidden results gallery (acts as app state), the
    before/after slider, the two modality selectors, the Advanced Settings
    accordion, and the examples gallery; then wires all event handlers.
    """
    self.make_header()

    # Hidden Gallery used purely as state: holds [image, modality_name]
    # pairs produced by `process_components`.
    results_state = Gallery(visible=False)

    image_slider = self.make_slider()

    with gr.Row():
        modality_selector_left, modality_selector_right = (
            self.make_modality_selectors(reverse_visual_order=False)
        )
    user_components, btn_clear, btn_submit = self.make_advanced_settings()

    # Clicking an example populates every interactive output at once:
    # state, slider, both selectors, and the user settings components.
    self.make_examples(
        image_slider,
        [
            results_state,
            image_slider,
            modality_selector_left,
            modality_selector_right,
            *user_components.values(),
        ],
    )

    # First processing run is triggered by uploading into the slider.
    image_slider.upload(
        fn=self.on_process_first,
        inputs=[
            image_slider,
            modality_selector_left,
            modality_selector_right,
            *user_components.values(),
        ],
        outputs=[
            results_state,
            image_slider,
            modality_selector_left,
            modality_selector_right,
            *user_components.values(),
        ],
    )

    # Subsequent runs ("Apply") reuse the cached input from results_state.
    btn_submit.click(
        fn=self.on_process_subsequent,
        inputs=[
            results_state,
            modality_selector_left,
            modality_selector_right,
            *user_components.values(),
        ],
        outputs=[
            results_state,
            image_slider,
            modality_selector_left,
            modality_selector_right,
            *user_components.values(),
        ],
    )

    # "Clear" resets both the slider and the cached results.
    btn_clear.click(
        fn=lambda: (None, None),
        inputs=[],
        outputs=[image_slider, results_state],
    )

    # Selector changes only swap which cached modality each pane shows;
    # no reprocessing happens.
    modality_selector_left.input(
        fn=self.on_selector_change_left,
        inputs=[results_state, image_slider, modality_selector_left],
        outputs=image_slider,
    )
    modality_selector_right.input(
        fn=self.on_selector_change_right,
        inputs=[results_state, image_slider, modality_selector_right],
        outputs=image_slider,
    )
|
527 |
+
|
528 |
+
def make_header(self):
    """Render the header section (default: just the project title).

    Subclasses may override this to add Markdown/HTML badges, links, etc.
    """
    gr.Markdown(f"# {self.title}")
|
534 |
+
|
535 |
+
def make_slider(self):
    """Create the main before/after image slider inside its own row."""
    # The wrapping row carries the "sliderrow" class so CSS can target it.
    with gr.Row(elem_classes="sliderrow"):
        return ImageSliderPlus(
            label=self.title,
            type="filepath",  # ImageSliderPlus only supports filepath mode
            elem_classes="slider",
            position=self.slider_position,
        )
|
543 |
+
|
544 |
+
def make_modality_selectors(self, reverse_visual_order=False):
    """Create the two Radio selectors that pick the left/right modalities.

    Both start empty (choices=None); `process_components` later replaces
    them with populated Radios. When `reverse_visual_order` is True the
    components are created unrendered and then rendered right-before-left.
    """
    modality_selector_left = Radio(
        choices=None,
        value=None,
        label="Left",
        key="Left",
        show_label=False,
        container=False,
        visible=self.left_selector_visible,
        render=not reverse_visual_order,  # defer rendering if order reversed
    )
    modality_selector_right = Radio(
        choices=None,
        value=None,
        label="Right",
        key="Right",
        show_label=False,
        container=False,
        elem_id="selector_right",
        visible=False,  # hidden until the first results arrive
        render=not reverse_visual_order,
    )
    if reverse_visual_order:
        # Render manually so the right selector appears first on screen.
        modality_selector_right.render()
        modality_selector_left.render()
    return modality_selector_left, modality_selector_right
|
570 |
+
|
571 |
+
def make_examples(self, inputs, outputs):
    """Create the examples gallery from images in `examples_path`.

    Args:
        inputs: Component(s) the example value is loaded into.
        outputs: Components updated when a (cached) example is selected.

    Raises:
        gr.Error: If `discover_examples` returns a non-list or a path that
            is not an existing file.
    """
    examples = self.discover_examples()
    if not isinstance(examples, list):
        raise gr.Error("`discover_examples` must return a list of paths")
    if any(not os.path.isfile(path) for path in examples):
        raise gr.Error("Not all example paths are valid files")
    # Use the examples directory name to namespace the examples cache
    # (see the patched Examples class).
    examples_dirname = os.path.basename(os.path.normpath(self.examples_path))
    return Examples(
        examples=[
            (e, e) for e in examples
        ],  # tuples like this seem to work better with the gallery
        inputs=inputs,
        outputs=outputs,
        examples_per_page=self.examples_per_page,
        cache_examples=self.examples_cache,
        fn=self.on_process_first,
        directory_name=examples_dirname,
    )
|
589 |
+
|
590 |
+
def make_advanced_settings(self):
    """Build the Advanced Settings accordion and record settings metadata.

    Renders the subclass-provided components plus the Clear/Apply buttons,
    and snapshots each component's class and constructor kwargs so
    `process_components` can recreate them with updated values.

    Returns:
        (user_components, btn_clear, btn_submit)

    Raises:
        gr.Error: If `build_user_components` does not return a dict of
            Gradio components keyed by strings.
    """
    with gr.Accordion("Advanced Settings", open=False, elem_id="settings-accordion"):
        user_components = self.build_user_components()
        if not isinstance(user_components, dict) or any(
            not isinstance(k, str) or not isinstance(v, Component)
            for k, v in user_components.items()
        ):
            raise gr.Error(
                "`build_user_components` must return a dict of Gradio components with string keys. A dict of the "
                "same structure will be passed into the `process` function."
            )
        with gr.Row():
            btn_clear, btn_submit = self.make_buttons()
        self.input_keys = list(user_components.keys())
        self.input_cls = list(v.__class__ for v in user_components.values())
        self.input_kwargs = [
            # BUG FIX: this previously read `k not in ("value")`, which is a
            # substring test against the string "value" (parenthesized
            # string, not a tuple) and silently dropped any kwarg whose name
            # is a substring of "value". Only the "value" kwarg itself must
            # be excluded, because it is supplied explicitly on re-creation.
            {k: v for k, v in c.constructor_args.items() if k != "value"}
            for c in user_components.values()
        ]
    return user_components, btn_clear, btn_submit
|
610 |
+
|
611 |
+
def make_buttons(self):
    """Create and return the (Clear, Apply) buttons, in that order."""
    return gr.Button("Clear"), gr.Button("Apply", variant="primary")
|
gradio_dualvision/gradio_patches/__init__.py
ADDED
File without changes
|
gradio_dualvision/gradio_patches/examples.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
2 |
+
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
|
3 |
+
# See https://creativecommons.org/licenses/by-sa/4.0/ for details.
|
4 |
+
# --------------------------------------------------------------------------
|
5 |
+
# DualVision is a Gradio template app for image processing. It was developed
|
6 |
+
# to support the Marigold project. If you find this code useful, we kindly
|
7 |
+
# ask you to cite our most relevant papers.
|
8 |
+
# More information about Marigold:
|
9 |
+
# https://marigoldmonodepth.github.io
|
10 |
+
# https://marigoldcomputervision.github.io
|
11 |
+
# Efficient inference pipelines are now part of diffusers:
|
12 |
+
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
|
13 |
+
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
|
14 |
+
# Examples of trained models and live demos:
|
15 |
+
# https://huggingface.co/prs-eth
|
16 |
+
# Related projects:
|
17 |
+
# https://marigolddepthcompletion.github.io/
|
18 |
+
# https://rollingdepth.github.io/
|
19 |
+
# Citation (BibTeX):
|
20 |
+
# https://github.com/prs-eth/Marigold#-citation
|
21 |
+
# https://github.com/prs-eth/Marigold-DC#-citation
|
22 |
+
# https://github.com/prs-eth/rollingdepth#-citation
|
23 |
+
# --------------------------------------------------------------------------
|
24 |
+
from pathlib import Path
|
25 |
+
import gradio
|
26 |
+
from gradio.utils import get_cache_folder
|
27 |
+
|
28 |
+
|
29 |
+
class Examples(gradio.helpers.Examples):
    """Patched `gr.Examples` that can relocate its cache into a named
    subdirectory of the Gradio cache folder, so multiple example sets do
    not collide."""

    def __init__(self, *args, directory_name=None, **kwargs):
        # _initiated_directly=False defers `create()` in the parent so the
        # cache paths can be redirected before anything is written.
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            self.cached_folder = get_cache_folder() / directory_name
            self.cached_file = Path(self.cached_folder) / "log.csv"
            self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
        self.create()
|
gradio_dualvision/gradio_patches/gallery.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
+
from gradio.components.gallery import (
|
5 |
+
GalleryImageType,
|
6 |
+
CaptionedGalleryImageType,
|
7 |
+
GalleryImage,
|
8 |
+
GalleryData,
|
9 |
+
)
|
10 |
+
from pathlib import Path
|
11 |
+
from urllib.parse import urlparse
|
12 |
+
|
13 |
+
import gradio
|
14 |
+
import numpy as np
|
15 |
+
import PIL.Image
|
16 |
+
from gradio_client.utils import is_http_url_like
|
17 |
+
|
18 |
+
from gradio import processing_utils, utils, wasm_utils
|
19 |
+
from gradio.data_classes import FileData
|
20 |
+
|
21 |
+
|
22 |
+
class Gallery(gradio.Gallery):
    def postprocess(
        self,
        value: list[GalleryImageType | CaptionedGalleryImageType] | None,
    ) -> GalleryData:
        """
        Patched `gr.Gallery.postprocess`: PIL images are always cached as
        lossless PNG (upstream picks "webp" unless the mode is "I;16").
        """
        if value is None:
            return GalleryData(root=[])

        def _save(img):
            url = None
            caption = None
            orig_name = None
            if isinstance(img, (tuple, list)):
                img, caption = img
            if isinstance(img, np.ndarray):
                cached = processing_utils.save_img_array_to_cache(
                    img, cache_dir=self.GRADIO_CACHE, format=self.format
                )
                file_path = str(utils.abspath(cached))
            elif isinstance(img, PIL.Image.Image):
                # Force PNG here instead of the upstream webp default.
                cached = processing_utils.save_pil_to_cache(
                    img, cache_dir=self.GRADIO_CACHE, format="png"
                )
                file_path = str(utils.abspath(cached))
            elif isinstance(img, str):
                file_path = img
                if is_http_url_like(img):
                    url = img
                    orig_name = Path(urlparse(img).path).name
                else:
                    url = None
                    orig_name = Path(img).name
            elif isinstance(img, Path):
                file_path = str(img)
                orig_name = img.name
            else:
                raise ValueError(f"Cannot process type as image: {type(img)}")
            return GalleryImage(
                image=FileData(path=file_path, url=url, orig_name=orig_name),
                caption=caption,
            )

        if wasm_utils.IS_WASM:
            # Threads are unavailable under WASM; save sequentially.
            output = [_save(img) for img in value]
        else:
            with ThreadPoolExecutor() as executor:
                output = list(executor.map(_save, value))
        return GalleryData(root=output)
|
gradio_dualvision/gradio_patches/gallery.pyi
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
+
from gradio.components.gallery import (
|
5 |
+
GalleryImageType,
|
6 |
+
CaptionedGalleryImageType,
|
7 |
+
GalleryImage,
|
8 |
+
GalleryData,
|
9 |
+
)
|
10 |
+
from pathlib import Path
|
11 |
+
from urllib.parse import urlparse
|
12 |
+
|
13 |
+
import gradio
|
14 |
+
import numpy as np
|
15 |
+
import PIL.Image
|
16 |
+
from gradio_client.utils import is_http_url_like
|
17 |
+
|
18 |
+
from gradio import processing_utils, utils, wasm_utils
|
19 |
+
from gradio.data_classes import FileData
|
20 |
+
|
21 |
+
from gradio.events import Dependency
|
22 |
+
|
23 |
+
# NOTE(review): this .pyi appears to be generated by Gradio's custom-component
# tooling and mirrors gallery.py verbatim — keep the two in sync; do not edit
# the logic here by hand.
class Gallery(gradio.Gallery):
    def postprocess(
        self,
        value: list[GalleryImageType | CaptionedGalleryImageType] | None,
    ) -> GalleryData:
        """
        This is a patched version of the original function, wherein the format for PIL is computed based on the data type:
        format = "png" if img.mode == "I;16" else "webp"
        """
        if value is None:
            return GalleryData(root=[])
        output = []

        def _save(img):
            url = None
            caption = None
            orig_name = None
            if isinstance(img, (tuple, list)):
                img, caption = img
            if isinstance(img, np.ndarray):
                file = processing_utils.save_img_array_to_cache(
                    img, cache_dir=self.GRADIO_CACHE, format=self.format
                )
                file_path = str(utils.abspath(file))
            elif isinstance(img, PIL.Image.Image):
                format = "png" #if img.mode == "I;16" else "webp"
                file = processing_utils.save_pil_to_cache(
                    img, cache_dir=self.GRADIO_CACHE, format=format
                )
                file_path = str(utils.abspath(file))
            elif isinstance(img, str):
                file_path = img
                if is_http_url_like(img):
                    url = img
                    orig_name = Path(urlparse(img).path).name
                else:
                    url = None
                    orig_name = Path(img).name
            elif isinstance(img, Path):
                file_path = str(img)
                orig_name = img.name
            else:
                raise ValueError(f"Cannot process type as image: {type(img)}")
            return GalleryImage(
                image=FileData(path=file_path, url=url, orig_name=orig_name),
                caption=caption,
            )

        if wasm_utils.IS_WASM:
            for img in value:
                output.append(_save(img))
        else:
            with ThreadPoolExecutor() as executor:
                for o in executor.map(_save, value):
                    output.append(o)
        return GalleryData(root=output)
|
79 |
+
from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
|
80 |
+
from gradio.blocks import Block
|
81 |
+
if TYPE_CHECKING:
|
82 |
+
from gradio.components import Timer
|
gradio_dualvision/gradio_patches/imagesliderplus.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
2 |
+
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
|
3 |
+
# See https://creativecommons.org/licenses/by-sa/4.0/ for details.
|
4 |
+
# --------------------------------------------------------------------------
|
5 |
+
# DualVision is a Gradio template app for image processing. It was developed
|
6 |
+
# to support the Marigold project. If you find this code useful, we kindly
|
7 |
+
# ask you to cite our most relevant papers.
|
8 |
+
# More information about Marigold:
|
9 |
+
# https://marigoldmonodepth.github.io
|
10 |
+
# https://marigoldcomputervision.github.io
|
11 |
+
# Efficient inference pipelines are now part of diffusers:
|
12 |
+
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
|
13 |
+
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
|
14 |
+
# Examples of trained models and live demos:
|
15 |
+
# https://huggingface.co/prs-eth
|
16 |
+
# Related projects:
|
17 |
+
# https://marigolddepthcompletion.github.io/
|
18 |
+
# https://rollingdepth.github.io/
|
19 |
+
# Citation (BibTeX):
|
20 |
+
# https://github.com/prs-eth/Marigold#-citation
|
21 |
+
# https://github.com/prs-eth/Marigold-DC#-citation
|
22 |
+
# https://github.com/prs-eth/rollingdepth#-citation
|
23 |
+
# --------------------------------------------------------------------------
|
24 |
+
import json
|
25 |
+
import os.path
|
26 |
+
import tempfile
|
27 |
+
from pathlib import Path
|
28 |
+
from typing import Union, Tuple, Optional
|
29 |
+
|
30 |
+
import numpy as np
|
31 |
+
from PIL import Image
|
32 |
+
from gradio import processing_utils
|
33 |
+
from gradio import utils
|
34 |
+
from gradio.data_classes import FileData, GradioRootModel, JsonData
|
35 |
+
from gradio_client import utils as client_utils
|
36 |
+
from gradio_imageslider import ImageSlider
|
37 |
+
from gradio_imageslider.imageslider import image_tuple, image_variants
|
38 |
+
|
39 |
+
|
40 |
+
class ImageSliderPlusData(GradioRootModel):
    # Wire format of the slider payload: (left_image, right_image), optionally
    # followed by a JSON settings object; the whole payload may also be None.
    root: Union[
        Tuple[FileData | None, FileData | None, JsonData | None],
        Tuple[FileData | None, FileData | None],
        None,
    ]
|
46 |
+
|
47 |
+
|
48 |
+
class ImageSliderPlus(ImageSlider):
    """Patched ImageSlider that forces PNG caching, square example
    thumbnails, and an optional JSON settings payload carried alongside the
    two images (as a "<path>.settings.json" sidecar file)."""

    data_model = ImageSliderPlusData

    def as_example(self, value):
        # Examples gallery shows a 256 px, center-cropped square thumbnail.
        return self.process_example_dims(value, 256, True)

    def _format_image(self, im: Image.Image):
        """Save a PIL image to the Gradio cache and return its file path."""
        if self.type != "filepath":
            raise ValueError("ImageSliderPlus can be only created with type='filepath'")
        if im is None:
            return im
        # Always PNG (upstream condition kept for reference).
        format = "png" #if im.mode == "I;16" else "webp"
        path = processing_utils.save_pil_to_cache(
            im, cache_dir=self.GRADIO_CACHE, format=format
        )
        self.temp_files.add(path)
        return path

    def _postprocess_image(self, y: image_variants):
        """Normalize an ndarray / PIL image / path value to a file path,
        caching in-memory images to disk first."""
        if isinstance(y, np.ndarray):
            format = "png" #if y.dtype == np.uint16 and y.squeeze().ndim == 2 else "webp"
            path = processing_utils.save_img_array_to_cache(
                y, cache_dir=self.GRADIO_CACHE, format=format
            )
        elif isinstance(y, Image.Image):
            format = "png" #if y.mode == "I;16" else "webp"
            path = processing_utils.save_pil_to_cache(
                y, cache_dir=self.GRADIO_CACHE, format=format
            )
        elif isinstance(y, (str, Path)):
            path = y if isinstance(y, str) else str(utils.abspath(y))
        else:
            raise ValueError("Cannot process this value as an Image")

        return path

    def postprocess(
        self,
        y: image_tuple,
    ) -> ImageSliderPlusData:
        """Convert an image pair into the wire payload, attaching the
        sidecar settings of the left image when present."""
        if y is None:
            return ImageSliderPlusData(root=(None, None, None))

        settings = None
        if type(y[0]) is str:
            # Only path inputs can have a sidecar settings file.
            settings_candidate_path = y[0] + ".settings.json"
            if os.path.isfile(settings_candidate_path):
                with open(settings_candidate_path, "r") as fp:
                    settings = json.load(fp)

        return ImageSliderPlusData(
            root=(
                FileData(path=self._postprocess_image(y[0])),
                FileData(path=self._postprocess_image(y[1])),
                JsonData(settings),
            ),
        )

    def preprocess(self, x: ImageSliderPlusData) -> image_tuple:
        """Convert the wire payload back to two file paths, re-materializing
        any settings as a sidecar file next to the left image."""
        if x is None:
            return x

        out_0 = self._preprocess_image(x.root[0])
        out_1 = self._preprocess_image(x.root[1])

        if len(x.root) > 2 and x.root[2] is not None:
            # Persist settings as "<left_image_path>.settings.json" so that
            # downstream code (see postprocess) can rediscover them.
            with open(out_0 + ".settings.json", "w") as fp:
                json.dump(x.root[2].root, fp)

        return out_0, out_1

    @staticmethod
    def resize_and_save(image_path: str, max_dim: int, square: bool = False) -> str:
        """Load an image, optionally center-crop it to a square, shrink it to
        at most max_dim on each side, and save it to a temp PNG file whose
        path is returned."""
        img = Image.open(image_path).convert("RGB")
        if square:
            width, height = img.size
            min_side = min(width, height)
            left = (width - min_side) // 2
            top = (height - min_side) // 2
            right = left + min_side
            bottom = top + min_side
            img = img.crop((left, top, right, bottom))
        img.thumbnail((max_dim, max_dim))
        # delete=False: the file must outlive this call; Gradio moves it into
        # its block cache afterwards.
        temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        img.save(temp_file.name, "PNG")
        return temp_file.name

    def process_example_dims(
        self, input_data: tuple[str | Path | None] | None, max_dim: Optional[int] = None, square: bool = False
    ) -> image_tuple:
        """Prepare an example pair for display, optionally resizing/cropping.

        NOTE(review): for proxied or HTTP inputs only the first element is
        returned (a single str, not a tuple) — this mirrors upstream
        ImageSlider behavior; confirm before changing.
        """
        if input_data is None:
            return None
        input_data = (str(input_data[0]), str(input_data[1]))
        if self.proxy_url or client_utils.is_http_url_like(input_data[0]):
            return input_data[0]
        if max_dim is not None:
            input_data = (
                self.resize_and_save(input_data[0], max_dim, square),
                self.resize_and_save(input_data[1], max_dim, square),
            )
        return (
            self.move_resource_to_block_cache(input_data[0]),
            self.move_resource_to_block_cache(input_data[1]),
        )

    def process_example(
        self, input_data: tuple[str | Path | None] | None
    ) -> image_tuple:
        # Full-size example processing: no resize, no crop.
        return self.process_example_dims(input_data)
|
gradio_dualvision/gradio_patches/imagesliderplus.pyi
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
2 |
+
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
|
3 |
+
# See https://creativecommons.org/licenses/by-sa/4.0/ for details.
|
4 |
+
# --------------------------------------------------------------------------
|
5 |
+
# DualVision is a Gradio template app for image processing. It was developed
|
6 |
+
# to support the Marigold project. If you find this code useful, we kindly
|
7 |
+
# ask you to cite our most relevant papers.
|
8 |
+
# More information about Marigold:
|
9 |
+
# https://marigoldmonodepth.github.io
|
10 |
+
# https://marigoldcomputervision.github.io
|
11 |
+
# Efficient inference pipelines are now part of diffusers:
|
12 |
+
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
|
13 |
+
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
|
14 |
+
# Examples of trained models and live demos:
|
15 |
+
# https://huggingface.co/prs-eth
|
16 |
+
# Related projects:
|
17 |
+
# https://marigolddepthcompletion.github.io/
|
18 |
+
# https://rollingdepth.github.io/
|
19 |
+
# Citation (BibTeX):
|
20 |
+
# https://github.com/prs-eth/Marigold#-citation
|
21 |
+
# https://github.com/prs-eth/Marigold-DC#-citation
|
22 |
+
# https://github.com/prs-eth/rollingdepth#-citation
|
23 |
+
# --------------------------------------------------------------------------
|
24 |
+
import json
|
25 |
+
import os.path
|
26 |
+
import tempfile
|
27 |
+
from pathlib import Path
|
28 |
+
from typing import Union, Tuple, Optional
|
29 |
+
|
30 |
+
import numpy as np
|
31 |
+
from PIL import Image
|
32 |
+
from gradio import processing_utils
|
33 |
+
from gradio import utils
|
34 |
+
from gradio.data_classes import FileData, GradioRootModel, JsonData
|
35 |
+
from gradio_client import utils as client_utils
|
36 |
+
from gradio_imageslider import ImageSlider
|
37 |
+
from gradio_imageslider.imageslider import image_tuple, image_variants
|
38 |
+
|
39 |
+
|
40 |
+
class ImageSliderPlusData(GradioRootModel):
    """Wire format of the ImageSliderPlus component.

    Either a 3-tuple of (left image, right image, JSON settings payload),
    a plain 2-tuple of images, or None when the component is empty.
    """

    # Each FileData slot may independently be None (missing image); the
    # JsonData slot carries arbitrary per-example settings.
    root: Union[
        Tuple[FileData | None, FileData | None, JsonData | None],
        Tuple[FileData | None, FileData | None],
        None,
    ]
|
46 |
+
|
47 |
+
from gradio.events import Dependency
|
48 |
+
|
49 |
+
class ImageSliderPlus(ImageSlider):
    """ImageSlider that additionally round-trips a JSON settings payload.

    The third slot of :class:`ImageSliderPlusData` carries arbitrary JSON
    settings; they are persisted as a sidecar file ``<image>.settings.json``
    next to the first image so downstream handlers can recover them from the
    file path alone. Only ``type='filepath'`` is supported.
    """

    data_model = ImageSliderPlusData

    def as_example(self, value):
        # Thumbnails for gr.Examples: center square crop, capped at 256 px.
        return self.process_example_dims(value, 256, True)

    def _format_image(self, im: Optional[Image.Image]):
        """Cache a PIL image as a PNG file and return its path.

        Note: the parameter was previously annotated as the PIL ``Image``
        *module*; the correct type is ``Image.Image`` (or None).
        """
        if self.type != "filepath":
            raise ValueError("ImageSliderPlus can be only created with type='filepath'")
        if im is None:
            return im
        # PNG is lossless; the webp branch was deliberately disabled.
        format = "png"  # if im.mode == "I;16" else "webp"
        path = processing_utils.save_pil_to_cache(
            im, cache_dir=self.GRADIO_CACHE, format=format
        )
        self.temp_files.add(path)
        return path

    def _postprocess_image(self, y: image_variants):
        """Normalize an ndarray / PIL image / path into a cached file path."""
        if isinstance(y, np.ndarray):
            format = "png"  # if y.dtype == np.uint16 and y.squeeze().ndim == 2 else "webp"
            path = processing_utils.save_img_array_to_cache(
                y, cache_dir=self.GRADIO_CACHE, format=format
            )
        elif isinstance(y, Image.Image):
            format = "png"  # if y.mode == "I;16" else "webp"
            path = processing_utils.save_pil_to_cache(
                y, cache_dir=self.GRADIO_CACHE, format=format
            )
        elif isinstance(y, (str, Path)):
            path = y if isinstance(y, str) else str(utils.abspath(y))
        else:
            raise ValueError("Cannot process this value as an Image")

        return path

    def postprocess(
        self,
        y: image_tuple,
    ) -> ImageSliderPlusData:
        """Convert an image pair into the wire format, picking up any
        ``<image>.settings.json`` sidecar written for the first image."""
        if y is None:
            return ImageSliderPlusData(root=(None, None, None))

        settings = None
        # isinstance instead of `type(...) is str` so str subclasses work too.
        if isinstance(y[0], str):
            settings_candidate_path = y[0] + ".settings.json"
            if os.path.isfile(settings_candidate_path):
                with open(settings_candidate_path, "r") as fp:
                    settings = json.load(fp)

        return ImageSliderPlusData(
            root=(
                FileData(path=self._postprocess_image(y[0])),
                FileData(path=self._postprocess_image(y[1])),
                JsonData(settings),
            ),
        )

    def preprocess(self, x: ImageSliderPlusData) -> image_tuple:
        """Convert wire data back into a pair of file paths, writing the
        settings sidecar next to the first image when present."""
        if x is None:
            return x

        out_0 = self._preprocess_image(x.root[0])
        out_1 = self._preprocess_image(x.root[1])

        if len(x.root) > 2 and x.root[2] is not None:
            # NOTE(review): assumes out_0 is a valid path when settings are
            # present — confirm settings never arrive with a missing image.
            with open(out_0 + ".settings.json", "w") as fp:
                json.dump(x.root[2].root, fp)

        return out_0, out_1

    @staticmethod
    def resize_and_save(image_path: str, max_dim: int, square: bool = False) -> str:
        """Downscale (and optionally center-crop to square) an image.

        Returns the path of a temporary PNG holding the result. The caller
        is responsible for the lifetime of the temp file.
        """
        img = Image.open(image_path).convert("RGB")
        if square:
            width, height = img.size
            min_side = min(width, height)
            left = (width - min_side) // 2
            top = (height - min_side) // 2
            img = img.crop((left, top, left + min_side, top + min_side))
        img.thumbnail((max_dim, max_dim))  # in-place, preserves aspect ratio
        # Close the OS-level handle before saving by name: keeping it open
        # leaks a descriptor and breaks the overwrite on Windows.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            out_path = tmp.name
        img.save(out_path, "PNG")
        return out_path

    def process_example_dims(
        self,
        input_data: tuple[str | Path | None] | None,
        max_dim: Optional[int] = None,
        square: bool = False,
    ) -> image_tuple:
        """Resolve an example tuple to cached file paths, optionally resizing."""
        if input_data is None:
            return None
        input_data = (str(input_data[0]), str(input_data[1]))
        if self.proxy_url or client_utils.is_http_url_like(input_data[0]):
            # Remote resources are passed through untouched (single URL).
            return input_data[0]
        if max_dim is not None:
            input_data = (
                self.resize_and_save(input_data[0], max_dim, square),
                self.resize_and_save(input_data[1], max_dim, square),
            )
        return (
            self.move_resource_to_block_cache(input_data[0]),
            self.move_resource_to_block_cache(input_data[1]),
        )

    def process_example(
        self, input_data: tuple[str | Path | None] | None
    ) -> image_tuple:
        return self.process_example_dims(input_data)
|
158 |
+
from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
|
159 |
+
from gradio.blocks import Block
|
160 |
+
if TYPE_CHECKING:
|
161 |
+
from gradio.components import Timer
|
gradio_dualvision/gradio_patches/radio.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
2 |
+
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
|
3 |
+
# See https://creativecommons.org/licenses/by-sa/4.0/ for details.
|
4 |
+
# --------------------------------------------------------------------------
|
5 |
+
# DualVision is a Gradio template app for image processing. It was developed
|
6 |
+
# to support the Marigold project. If you find this code useful, we kindly
|
7 |
+
# ask you to cite our most relevant papers.
|
8 |
+
# More information about Marigold:
|
9 |
+
# https://marigoldmonodepth.github.io
|
10 |
+
# https://marigoldcomputervision.github.io
|
11 |
+
# Efficient inference pipelines are now part of diffusers:
|
12 |
+
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
|
13 |
+
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
|
14 |
+
# Examples of trained models and live demos:
|
15 |
+
# https://huggingface.co/prs-eth
|
16 |
+
# Related projects:
|
17 |
+
# https://marigolddepthcompletion.github.io/
|
18 |
+
# https://rollingdepth.github.io/
|
19 |
+
# Citation (BibTeX):
|
20 |
+
# https://github.com/prs-eth/Marigold#-citation
|
21 |
+
# https://github.com/prs-eth/Marigold-DC#-citation
|
22 |
+
# https://github.com/prs-eth/rollingdepth#-citation
|
23 |
+
# --------------------------------------------------------------------------
|
24 |
+
import gradio
|
25 |
+
from gradio import components
|
26 |
+
from gradio.components.base import Component
|
27 |
+
from gradio.data_classes import (
|
28 |
+
GradioModel,
|
29 |
+
GradioRootModel,
|
30 |
+
)
|
31 |
+
|
32 |
+
from gradio.blocks import BlockContext
|
33 |
+
|
34 |
+
|
35 |
+
def patched_postprocess_update_dict(
    block: Component | BlockContext, update_dict: dict, postprocess: bool = True
):
    """
    This is a patched version of the original function where 'pop' is replaced with 'get' in the first line.
    The key will no longer be removed but can still be accessed safely.
    This fixed gradio.Radio component persisting the value selection through gradio.Examples.
    """
    # Read (but keep) the explicit value so later consumers of update_dict
    # still see the "value" key.
    value = update_dict.get("value", components._Keywords.NO_VALUE)

    # Continue with the original logic
    update_dict = {k: getattr(block, k) for k in update_dict if hasattr(block, k)}
    if value is not components._Keywords.NO_VALUE:
        if postprocess:
            update_dict["value"] = block.postprocess(value)
            # Pydantic models must be serialized before being sent to the client.
            if isinstance(update_dict["value"], (GradioModel, GradioRootModel)):
                update_dict["value"] = update_dict["value"].model_dump()
        else:
            update_dict["value"] = value
    update_dict["__type__"] = "update"
    return update_dict


# Monkey-patch gradio so every component update goes through the fixed logic.
gradio.blocks.postprocess_update_dict = patched_postprocess_update_dict


class Radio(gradio.Radio):
    # Behaviorally identical to gradio.Radio; importing this module applies
    # the patch above, so using this class makes the dependency explicit.
    pass
|
gradio_dualvision/gradio_patches/radio.pyi
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
2 |
+
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
|
3 |
+
# See https://creativecommons.org/licenses/by-sa/4.0/ for details.
|
4 |
+
# --------------------------------------------------------------------------
|
5 |
+
# DualVision is a Gradio template app for image processing. It was developed
|
6 |
+
# to support the Marigold project. If you find this code useful, we kindly
|
7 |
+
# ask you to cite our most relevant papers.
|
8 |
+
# More information about Marigold:
|
9 |
+
# https://marigoldmonodepth.github.io
|
10 |
+
# https://marigoldcomputervision.github.io
|
11 |
+
# Efficient inference pipelines are now part of diffusers:
|
12 |
+
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
|
13 |
+
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
|
14 |
+
# Examples of trained models and live demos:
|
15 |
+
# https://huggingface.co/prs-eth
|
16 |
+
# Related projects:
|
17 |
+
# https://marigolddepthcompletion.github.io/
|
18 |
+
# https://rollingdepth.github.io/
|
19 |
+
# Citation (BibTeX):
|
20 |
+
# https://github.com/prs-eth/Marigold#-citation
|
21 |
+
# https://github.com/prs-eth/Marigold-DC#-citation
|
22 |
+
# https://github.com/prs-eth/rollingdepth#-citation
|
23 |
+
# --------------------------------------------------------------------------
|
24 |
+
import gradio
|
25 |
+
from gradio import components
|
26 |
+
from gradio.components.base import Component
|
27 |
+
from gradio.data_classes import (
|
28 |
+
GradioModel,
|
29 |
+
GradioRootModel,
|
30 |
+
)
|
31 |
+
|
32 |
+
from gradio.blocks import BlockContext
|
33 |
+
|
34 |
+
|
35 |
+
def patched_postprocess_update_dict(
    block: Component | BlockContext, update_dict: dict, postprocess: bool = True
):
    """
    This is a patched version of the original function where 'pop' is replaced with 'get' in the first line.
    The key will no longer be removed but can still be accessed safely.
    This fixed gradio.Radio component persisting the value selection through gradio.Examples.
    """
    # Read (but keep) the explicit value so later consumers of update_dict
    # still see the "value" key.
    value = update_dict.get("value", components._Keywords.NO_VALUE)

    # Continue with the original logic
    update_dict = {k: getattr(block, k) for k in update_dict if hasattr(block, k)}
    if value is not components._Keywords.NO_VALUE:
        if postprocess:
            update_dict["value"] = block.postprocess(value)
            # Pydantic models must be serialized before being sent to the client.
            if isinstance(update_dict["value"], (GradioModel, GradioRootModel)):
                update_dict["value"] = update_dict["value"].model_dump()
        else:
            update_dict["value"] = value
    update_dict["__type__"] = "update"
    return update_dict


# Monkey-patch gradio so every component update goes through the fixed logic.
gradio.blocks.postprocess_update_dict = patched_postprocess_update_dict

# Import inserted by the gradio custom-component stub generator.
from gradio.events import Dependency

class Radio(gradio.Radio):
    # Behaviorally identical to gradio.Radio; importing this module applies
    # the patch above, so using this class makes the dependency explicit.
    pass
|
64 |
+
from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
|
65 |
+
from gradio.blocks import Block
|
66 |
+
if TYPE_CHECKING:
|
67 |
+
from gradio.components import Timer
|
gradio_dualvision/gradio_patches/templates/component/__vite-browser-external-2447137e.js
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Vite build shim: stands in for Node built-ins ("browser-external" modules)
// that have no browser implementation; exports an empty object as default.
const e = {};
export {
  e as default
};
|
gradio_dualvision/gradio_patches/templates/component/index.js
ADDED
The diff for this file is too large to render.
See raw diff
|
|
gradio_dualvision/gradio_patches/templates/component/style.css
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.wrap.svelte-1w37x6c.svelte-1w37x6c{position:relative;width:100%;height:100%;z-index:100}.icon-wrap.svelte-1w37x6c.svelte-1w37x6c{position:absolute;top:50%;transform:translateY(-50%);left:-40px;width:32px;transition:.2s;color:var(--body-text-color)}.icon-wrap.right.svelte-1w37x6c.svelte-1w37x6c{left:60px;transform:translateY(-50%) translate(-100%) rotate(180deg)}.icon-wrap.active.svelte-1w37x6c.svelte-1w37x6c,.icon-wrap.disabled.svelte-1w37x6c.svelte-1w37x6c{opacity:0}.outer.svelte-1w37x6c.svelte-1w37x6c{width:20px;height:100%;cursor:grab;position:absolute;top:0;left:0}.inner.svelte-1w37x6c.svelte-1w37x6c{box-shadow:-1px 0 6px 1px #0003;width:1px;height:100%;background:var(--color);position:absolute;left:calc((100% - 2px)/2)}.disabled.svelte-1w37x6c.svelte-1w37x6c{cursor:auto}.disabled.svelte-1w37x6c .inner.svelte-1w37x6c{box-shadow:none}.block.svelte-1t38q2d{position:relative;margin:0;box-shadow:var(--block-shadow);border-width:var(--block-border-width);border-color:var(--block-border-color);border-radius:var(--block-radius);background:var(--block-background-fill);width:100%;line-height:var(--line-sm)}.block.border_focus.svelte-1t38q2d{border-color:var(--color-accent)}.padded.svelte-1t38q2d{padding:var(--block-padding)}.hidden.svelte-1t38q2d{display:none}.hide-container.svelte-1t38q2d{margin:0;box-shadow:none;--block-border-width:0;background:transparent;padding:0;overflow:visible}div.svelte-1hnfib2{margin-bottom:var(--spacing-lg);color:var(--block-info-text-color);font-weight:var(--block-info-text-weight);font-size:var(--block-info-text-size);line-height:var(--line-sm)}span.has-info.svelte-22c38v{margin-bottom:var(--spacing-xs)}span.svelte-22c38v:not(.has-info){margin-bottom:var(--spacing-lg)}span.svelte-22c38v{display:inline-block;position:relative;z-index:var(--layer-4);border:solid var(--block-title-border-width) 
var(--block-title-border-color);border-radius:var(--block-title-radius);background:var(--block-title-background-fill);padding:var(--block-title-padding);color:var(--block-title-text-color);font-weight:var(--block-title-text-weight);font-size:var(--block-title-text-size);line-height:var(--line-sm)}.hide.svelte-22c38v{margin:0;height:0}label.svelte-9gxdi0{display:inline-flex;align-items:center;z-index:var(--layer-2);box-shadow:var(--block-label-shadow);border:var(--block-label-border-width) solid var(--border-color-primary);border-top:none;border-left:none;border-radius:var(--block-label-radius);background:var(--block-label-background-fill);padding:var(--block-label-padding);pointer-events:none;color:var(--block-label-text-color);font-weight:var(--block-label-text-weight);font-size:var(--block-label-text-size);line-height:var(--line-sm)}.gr-group label.svelte-9gxdi0{border-top-left-radius:0}label.float.svelte-9gxdi0{position:absolute;top:var(--block-label-margin);left:var(--block-label-margin)}label.svelte-9gxdi0:not(.float){position:static;margin-top:var(--block-label-margin);margin-left:var(--block-label-margin)}.hide.svelte-9gxdi0{height:0}span.svelte-9gxdi0{opacity:.8;margin-right:var(--size-2);width:calc(var(--block-label-text-size) - 1px);height:calc(var(--block-label-text-size) - 1px)}.hide-label.svelte-9gxdi0{box-shadow:none;border-width:0;background:transparent;overflow:visible}button.svelte-lpi64a{display:flex;justify-content:center;align-items:center;gap:1px;z-index:var(--layer-2);border-radius:var(--radius-sm);color:var(--block-label-text-color);border:1px solid transparent}button[disabled].svelte-lpi64a{opacity:.5;box-shadow:none}button[disabled].svelte-lpi64a:hover{cursor:not-allowed}.padded.svelte-lpi64a{padding:2px;background:var(--bg-color);box-shadow:var(--shadow-drop);border:1px solid 
var(--button-secondary-border-color)}button.svelte-lpi64a:hover,button.highlight.svelte-lpi64a{cursor:pointer;color:var(--color-accent)}.padded.svelte-lpi64a:hover{border:2px solid var(--button-secondary-border-color-hover);padding:1px;color:var(--block-label-text-color)}span.svelte-lpi64a{padding:0 1px;font-size:10px}div.svelte-lpi64a{padding:2px;display:flex;align-items:flex-end}.small.svelte-lpi64a{width:14px;height:14px}.large.svelte-lpi64a{width:22px;height:22px}.pending.svelte-lpi64a{animation:svelte-lpi64a-flash .5s infinite}@keyframes svelte-lpi64a-flash{0%{opacity:.5}50%{opacity:1}to{opacity:.5}}.transparent.svelte-lpi64a{background:transparent;border:none;box-shadow:none}.empty.svelte-3w3rth{display:flex;justify-content:center;align-items:center;margin-top:calc(0px - var(--size-6));height:var(--size-full)}.icon.svelte-3w3rth{opacity:.5;height:var(--size-5);color:var(--body-text-color)}.small.svelte-3w3rth{min-height:calc(var(--size-32) - 20px)}.large.svelte-3w3rth{min-height:calc(var(--size-64) - 20px)}.unpadded_box.svelte-3w3rth{margin-top:0}.small_parent.svelte-3w3rth{min-height:100%!important}.dropdown-arrow.svelte-145leq6{fill:currentColor}.wrap.svelte-kzcjhc{display:flex;flex-direction:column;justify-content:center;align-items:center;min-height:var(--size-60);color:var(--block-label-text-color);line-height:var(--line-md);height:100%;padding-top:var(--size-3)}.or.svelte-kzcjhc{color:var(--body-text-color-subdued);display:flex}.icon-wrap.svelte-kzcjhc{width:30px;margin-bottom:var(--spacing-lg)}@media (--screen-md){.wrap.svelte-kzcjhc{font-size:var(--text-lg)}}.hovered.svelte-kzcjhc{color:var(--color-accent)}div.svelte-ipfyu7{border-top:1px solid transparent;display:flex;max-height:100%;justify-content:center;gap:var(--spacing-sm);height:auto;align-items:flex-end;padding-bottom:var(--spacing-xl);color:var(--block-label-text-color);flex-shrink:0;width:95%}.show_border.svelte-ipfyu7{border-top:1px solid 
var(--block-border-color);margin-top:var(--spacing-xxl);box-shadow:var(--shadow-drop)}.source-selection.svelte-lde7lt{display:flex;align-items:center;justify-content:center;border-top:1px solid var(--border-color-primary);width:95%;bottom:0;left:0;right:0;margin-left:auto;margin-right:auto;align-self:flex-end}.icon.svelte-lde7lt{width:22px;height:22px;margin:var(--spacing-lg) var(--spacing-xs);padding:var(--spacing-xs);color:var(--neutral-400);border-radius:var(--radius-md)}.selected.svelte-lde7lt{color:var(--color-accent)}.icon.svelte-lde7lt:hover,.icon.svelte-lde7lt:focus{color:var(--color-accent)}div.svelte-1g74h68{display:flex;position:absolute;top:var(--size-2);right:var(--size-2);justify-content:flex-end;gap:var(--spacing-sm);z-index:var(--layer-5)}.wrap.svelte-1juivz4.svelte-1juivz4{overflow-y:auto;transition:opacity .5s ease-in-out;background:var(--block-background-fill);position:relative;display:flex;flex-direction:column;align-items:center;justify-content:center;min-height:var(--size-40);width:100%}.wrap.svelte-1juivz4.svelte-1juivz4:after{content:"";position:absolute;top:0;left:0;width:var(--upload-progress-width);height:100%;transition:all .5s ease-in-out;z-index:1}.uploading.svelte-1juivz4.svelte-1juivz4{font-size:var(--text-lg);font-family:var(--font);z-index:2}.file-name.svelte-1juivz4.svelte-1juivz4{margin:var(--spacing-md);font-size:var(--text-lg);color:var(--body-text-color-subdued)}.file.svelte-1juivz4.svelte-1juivz4{font-size:var(--text-md);z-index:2;display:flex;align-items:center}.file.svelte-1juivz4 progress.svelte-1juivz4{display:inline;height:var(--size-1);width:100%;transition:all .5s ease-in-out;color:var(--color-accent);border:none}.file.svelte-1juivz4 progress[value].svelte-1juivz4::-webkit-progress-value{background-color:var(--color-accent);border-radius:20px}.file.svelte-1juivz4 
progress[value].svelte-1juivz4::-webkit-progress-bar{background-color:var(--border-color-accent);border-radius:20px}.progress-bar.svelte-1juivz4.svelte-1juivz4{width:14px;height:14px;border-radius:50%;background:radial-gradient(closest-side,var(--block-background-fill) 64%,transparent 53% 100%),conic-gradient(var(--color-accent) var(--upload-progress-width),var(--border-color-accent) 0);transition:all .5s ease-in-out}button.svelte-1aq8tno{cursor:pointer;width:var(--size-full)}.hidden.svelte-1aq8tno{display:none;height:0!important;position:absolute;width:0;flex-grow:0}.center.svelte-1aq8tno{display:flex;justify-content:center}.flex.svelte-1aq8tno{display:flex;justify-content:center;align-items:center}input.svelte-1aq8tno{display:none}div.svelte-1wj0ocy{display:flex;top:var(--size-2);right:var(--size-2);justify-content:flex-end;gap:var(--spacing-sm);z-index:var(--layer-1)}.not-absolute.svelte-1wj0ocy{margin:var(--size-1)}.upload-wrap.svelte-106mu0e.svelte-106mu0e{display:flex;justify-content:center;align-items:center;height:100%;width:100%}.wrap.svelte-106mu0e.svelte-106mu0e{width:100%}.half-wrap.svelte-106mu0e.svelte-106mu0e{width:50%}.image-container.svelte-106mu0e.svelte-106mu0e,img.svelte-106mu0e.svelte-106mu0e,.empty-wrap.svelte-106mu0e.svelte-106mu0e{width:var(--size-full);height:var(--size-full)}img.svelte-106mu0e.svelte-106mu0e{object-fit:cover}.fixed.svelte-106mu0e.svelte-106mu0e{--anim-block-background-fill:255, 255, 255;position:absolute;top:0;left:0;background-color:rgba(var(--anim-block-background-fill),.8);z-index:0}@media (prefers-color-scheme: dark){.fixed.svelte-106mu0e.svelte-106mu0e{--anim-block-background-fill:31, 41, 55}}.side-by-side.svelte-106mu0e 
img.svelte-106mu0e{width:50%;object-fit:contain}.empty-wrap.svelte-106mu0e.svelte-106mu0e{pointer-events:none}.icon-buttons.svelte-106mu0e.svelte-106mu0e{display:flex;position:absolute;right:8px;z-index:var(--layer-top);top:8px}svg.svelte-43sxxs.svelte-43sxxs{width:var(--size-20);height:var(--size-20)}svg.svelte-43sxxs path.svelte-43sxxs{fill:var(--loader-color)}div.svelte-43sxxs.svelte-43sxxs{z-index:var(--layer-2)}.margin.svelte-43sxxs.svelte-43sxxs{margin:var(--size-4)}.wrap.svelte-1txqlrd.svelte-1txqlrd{display:flex;flex-direction:column;justify-content:center;align-items:center;z-index:var(--layer-top);transition:opacity .1s ease-in-out;border-radius:var(--block-radius);background:var(--block-background-fill);padding:0 var(--size-6);max-height:var(--size-screen-h);overflow:hidden;pointer-events:none}.wrap.center.svelte-1txqlrd.svelte-1txqlrd{top:0;right:0;left:0}.wrap.default.svelte-1txqlrd.svelte-1txqlrd{top:0;right:0;bottom:0;left:0}.hide.svelte-1txqlrd.svelte-1txqlrd{opacity:0;pointer-events:none}.generating.svelte-1txqlrd.svelte-1txqlrd{animation:svelte-1txqlrd-pulse 2s cubic-bezier(.4,0,.6,1) infinite;border:2px solid var(--color-accent);background:transparent}.translucent.svelte-1txqlrd.svelte-1txqlrd{background:none}@keyframes svelte-1txqlrd-pulse{0%,to{opacity:1}50%{opacity:.5}}.loading.svelte-1txqlrd.svelte-1txqlrd{z-index:var(--layer-2);color:var(--body-text-color)}.eta-bar.svelte-1txqlrd.svelte-1txqlrd{position:absolute;top:0;right:0;bottom:0;left:0;transform-origin:left;opacity:.8;z-index:var(--layer-1);transition:10ms;background:var(--background-fill-secondary)}.progress-bar-wrap.svelte-1txqlrd.svelte-1txqlrd{border:1px solid 
var(--border-color-primary);background:var(--background-fill-primary);width:55.5%;height:var(--size-4)}.progress-bar.svelte-1txqlrd.svelte-1txqlrd{transform-origin:left;background-color:var(--loader-color);width:var(--size-full);height:var(--size-full)}.progress-level.svelte-1txqlrd.svelte-1txqlrd{display:flex;flex-direction:column;align-items:center;gap:1;z-index:var(--layer-2);width:var(--size-full)}.progress-level-inner.svelte-1txqlrd.svelte-1txqlrd{margin:var(--size-2) auto;color:var(--body-text-color);font-size:var(--text-sm);font-family:var(--font-mono)}.meta-text.svelte-1txqlrd.svelte-1txqlrd{position:absolute;top:0;right:0;z-index:var(--layer-2);padding:var(--size-1) var(--size-2);font-size:var(--text-sm);font-family:var(--font-mono)}.meta-text-center.svelte-1txqlrd.svelte-1txqlrd{display:flex;position:absolute;top:0;right:0;justify-content:center;align-items:center;transform:translateY(var(--size-6));z-index:var(--layer-2);padding:var(--size-1) var(--size-2);font-size:var(--text-sm);font-family:var(--font-mono);text-align:center}.error.svelte-1txqlrd.svelte-1txqlrd{box-shadow:var(--shadow-drop);border:solid 1px var(--error-border-color);border-radius:var(--radius-full);background:var(--error-background-fill);padding-right:var(--size-4);padding-left:var(--size-4);color:var(--error-text-color);font-weight:var(--weight-semibold);font-size:var(--text-lg);line-height:var(--line-lg);font-family:var(--font)}.minimal.svelte-1txqlrd .progress-text.svelte-1txqlrd{background:var(--block-background-fill)}.border.svelte-1txqlrd.svelte-1txqlrd{border:1px solid var(--border-color-primary)}.toast-body.svelte-solcu7{display:flex;position:relative;right:0;left:0;align-items:center;margin:var(--size-6) var(--size-4);margin:auto;border-radius:var(--container-radius);overflow:hidden;pointer-events:auto}.toast-body.error.svelte-solcu7{border:1px solid var(--color-red-700);background:var(--color-red-50)}.dark .toast-body.error.svelte-solcu7{border:1px solid 
var(--color-red-500);background-color:var(--color-grey-950)}.toast-body.warning.svelte-solcu7{border:1px solid var(--color-yellow-700);background:var(--color-yellow-50)}.dark .toast-body.warning.svelte-solcu7{border:1px solid var(--color-yellow-500);background-color:var(--color-grey-950)}.toast-body.info.svelte-solcu7{border:1px solid var(--color-grey-700);background:var(--color-grey-50)}.dark .toast-body.info.svelte-solcu7{border:1px solid var(--color-grey-500);background-color:var(--color-grey-950)}.toast-title.svelte-solcu7{display:flex;align-items:center;font-weight:var(--weight-bold);font-size:var(--text-lg);line-height:var(--line-sm);text-transform:capitalize}.toast-title.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-title.error.svelte-solcu7{color:var(--color-red-50)}.toast-title.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-title.warning.svelte-solcu7{color:var(--color-yellow-50)}.toast-title.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-title.info.svelte-solcu7{color:var(--color-grey-50)}.toast-close.svelte-solcu7{margin:0 var(--size-3);border-radius:var(--size-3);padding:0px var(--size-1-5);font-size:var(--size-5);line-height:var(--size-5)}.toast-close.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-close.error.svelte-solcu7{color:var(--color-red-500)}.toast-close.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-close.warning.svelte-solcu7{color:var(--color-yellow-500)}.toast-close.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-close.info.svelte-solcu7{color:var(--color-grey-500)}.toast-text.svelte-solcu7{font-size:var(--text-lg)}.toast-text.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-text.error.svelte-solcu7{color:var(--color-red-50)}.toast-text.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-text.warning.svelte-solcu7{color:var(--color-yellow-50)}.toast-text.info.svelte-solcu7{color:var(--color-grey-700)}.dark 
.toast-text.info.svelte-solcu7{color:var(--color-grey-50)}.toast-details.svelte-solcu7{margin:var(--size-3) var(--size-3) var(--size-3) 0;width:100%}.toast-icon.svelte-solcu7{display:flex;position:absolute;position:relative;flex-shrink:0;justify-content:center;align-items:center;margin:var(--size-2);border-radius:var(--radius-full);padding:var(--size-1);padding-left:calc(var(--size-1) - 1px);width:35px;height:35px}.toast-icon.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-icon.error.svelte-solcu7{color:var(--color-red-500)}.toast-icon.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-icon.warning.svelte-solcu7{color:var(--color-yellow-500)}.toast-icon.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-icon.info.svelte-solcu7{color:var(--color-grey-500)}@keyframes svelte-solcu7-countdown{0%{transform:scaleX(1)}to{transform:scaleX(0)}}.timer.svelte-solcu7{position:absolute;bottom:0;left:0;transform-origin:0 0;animation:svelte-solcu7-countdown 10s linear forwards;width:100%;height:var(--size-1)}.timer.error.svelte-solcu7{background:var(--color-red-700)}.dark .timer.error.svelte-solcu7{background:var(--color-red-500)}.timer.warning.svelte-solcu7{background:var(--color-yellow-700)}.dark .timer.warning.svelte-solcu7{background:var(--color-yellow-500)}.timer.info.svelte-solcu7{background:var(--color-grey-700)}.dark .timer.info.svelte-solcu7{background:var(--color-grey-500)}.toast-wrap.svelte-gatr8h{display:flex;position:fixed;top:var(--size-4);right:var(--size-4);flex-direction:column;align-items:end;gap:var(--size-2);z-index:var(--layer-top);width:calc(100% - var(--size-8))}@media (--screen-sm){.toast-wrap.svelte-gatr8h{width:calc(var(--size-96) + var(--size-10))}}.slider-wrap.svelte-a2zf8o{-webkit-user-select:none;user-select:none;max-height:calc(100vh - 
40px)}img.svelte-a2zf8o{width:var(--size-full);height:var(--size-full);object-fit:cover}.fixed.svelte-a2zf8o{position:absolute;top:0;left:0}.hidden.svelte-a2zf8o{opacity:0}.icon-buttons.svelte-a2zf8o{display:flex;position:absolute;right:8px;z-index:var(--layer-top);top:8px}.status-wrap.svelte-6wvohu{position:absolute;height:100%;width:100%;--anim-block-background-fill:255, 255, 255;z-index:1;pointer-events:none}@media (prefers-color-scheme: dark){.status-wrap.svelte-6wvohu{--anim-block-background-fill:31, 41, 55}}@keyframes svelte-6wvohu-pulse{0%{background-color:rgba(var(--anim-block-background-fill),.7)}50%{background-color:rgba(var(--anim-block-background-fill),.4)}to{background-color:rgba(var(--anim-block-background-fill),.7)}}.status-wrap.half.svelte-6wvohu .wrap{border-radius:0;animation:svelte-6wvohu-pulse 1.4s infinite ease-in-out}.status-wrap.half.svelte-6wvohu .progress-text{background:none!important}.status-wrap.half.svelte-6wvohu .eta-bar{opacity:0}.icon-buttons.svelte-6wvohu{display:flex;position:absolute;right:6px;z-index:var(--layer-1);top:6px}
|
gradio_dualvision/gradio_patches/templates/component/wrapper-6f348d45-19fa94bf.js
ADDED
@@ -0,0 +1,2453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import S from "./__vite-browser-external-2447137e.js";
|
2 |
+
function z(s) {
|
3 |
+
return s && s.__esModule && Object.prototype.hasOwnProperty.call(s, "default") ? s.default : s;
|
4 |
+
}
|
5 |
+
function gt(s) {
|
6 |
+
if (s.__esModule)
|
7 |
+
return s;
|
8 |
+
var e = s.default;
|
9 |
+
if (typeof e == "function") {
|
10 |
+
var t = function r() {
|
11 |
+
if (this instanceof r) {
|
12 |
+
var i = [null];
|
13 |
+
i.push.apply(i, arguments);
|
14 |
+
var n = Function.bind.apply(e, i);
|
15 |
+
return new n();
|
16 |
+
}
|
17 |
+
return e.apply(this, arguments);
|
18 |
+
};
|
19 |
+
t.prototype = e.prototype;
|
20 |
+
} else
|
21 |
+
t = {};
|
22 |
+
return Object.defineProperty(t, "__esModule", { value: !0 }), Object.keys(s).forEach(function(r) {
|
23 |
+
var i = Object.getOwnPropertyDescriptor(s, r);
|
24 |
+
Object.defineProperty(t, r, i.get ? i : {
|
25 |
+
enumerable: !0,
|
26 |
+
get: function() {
|
27 |
+
return s[r];
|
28 |
+
}
|
29 |
+
});
|
30 |
+
}), t;
|
31 |
+
}
|
32 |
+
const { Duplex: yt } = S;
|
33 |
+
function Oe(s) {
|
34 |
+
s.emit("close");
|
35 |
+
}
|
36 |
+
function vt() {
|
37 |
+
!this.destroyed && this._writableState.finished && this.destroy();
|
38 |
+
}
|
39 |
+
function Qe(s) {
|
40 |
+
this.removeListener("error", Qe), this.destroy(), this.listenerCount("error") === 0 && this.emit("error", s);
|
41 |
+
}
|
42 |
+
function St(s, e) {
|
43 |
+
let t = !0;
|
44 |
+
const r = new yt({
|
45 |
+
...e,
|
46 |
+
autoDestroy: !1,
|
47 |
+
emitClose: !1,
|
48 |
+
objectMode: !1,
|
49 |
+
writableObjectMode: !1
|
50 |
+
});
|
51 |
+
return s.on("message", function(n, o) {
|
52 |
+
const l = !o && r._readableState.objectMode ? n.toString() : n;
|
53 |
+
r.push(l) || s.pause();
|
54 |
+
}), s.once("error", function(n) {
|
55 |
+
r.destroyed || (t = !1, r.destroy(n));
|
56 |
+
}), s.once("close", function() {
|
57 |
+
r.destroyed || r.push(null);
|
58 |
+
}), r._destroy = function(i, n) {
|
59 |
+
if (s.readyState === s.CLOSED) {
|
60 |
+
n(i), process.nextTick(Oe, r);
|
61 |
+
return;
|
62 |
+
}
|
63 |
+
let o = !1;
|
64 |
+
s.once("error", function(f) {
|
65 |
+
o = !0, n(f);
|
66 |
+
}), s.once("close", function() {
|
67 |
+
o || n(i), process.nextTick(Oe, r);
|
68 |
+
}), t && s.terminate();
|
69 |
+
}, r._final = function(i) {
|
70 |
+
if (s.readyState === s.CONNECTING) {
|
71 |
+
s.once("open", function() {
|
72 |
+
r._final(i);
|
73 |
+
});
|
74 |
+
return;
|
75 |
+
}
|
76 |
+
s._socket !== null && (s._socket._writableState.finished ? (i(), r._readableState.endEmitted && r.destroy()) : (s._socket.once("finish", function() {
|
77 |
+
i();
|
78 |
+
}), s.close()));
|
79 |
+
}, r._read = function() {
|
80 |
+
s.isPaused && s.resume();
|
81 |
+
}, r._write = function(i, n, o) {
|
82 |
+
if (s.readyState === s.CONNECTING) {
|
83 |
+
s.once("open", function() {
|
84 |
+
r._write(i, n, o);
|
85 |
+
});
|
86 |
+
return;
|
87 |
+
}
|
88 |
+
s.send(i, o);
|
89 |
+
}, r.on("end", vt), r.on("error", Qe), r;
|
90 |
+
}
|
91 |
+
var Et = St;
|
92 |
+
const Vs = /* @__PURE__ */ z(Et);
|
93 |
+
var te = { exports: {} }, U = {
|
94 |
+
BINARY_TYPES: ["nodebuffer", "arraybuffer", "fragments"],
|
95 |
+
EMPTY_BUFFER: Buffer.alloc(0),
|
96 |
+
GUID: "258EAFA5-E914-47DA-95CA-C5AB0DC85B11",
|
97 |
+
kForOnEventAttribute: Symbol("kIsForOnEventAttribute"),
|
98 |
+
kListener: Symbol("kListener"),
|
99 |
+
kStatusCode: Symbol("status-code"),
|
100 |
+
kWebSocket: Symbol("websocket"),
|
101 |
+
NOOP: () => {
|
102 |
+
}
|
103 |
+
}, bt, xt;
|
104 |
+
const { EMPTY_BUFFER: kt } = U, Se = Buffer[Symbol.species];
|
105 |
+
function wt(s, e) {
|
106 |
+
if (s.length === 0)
|
107 |
+
return kt;
|
108 |
+
if (s.length === 1)
|
109 |
+
return s[0];
|
110 |
+
const t = Buffer.allocUnsafe(e);
|
111 |
+
let r = 0;
|
112 |
+
for (let i = 0; i < s.length; i++) {
|
113 |
+
const n = s[i];
|
114 |
+
t.set(n, r), r += n.length;
|
115 |
+
}
|
116 |
+
return r < e ? new Se(t.buffer, t.byteOffset, r) : t;
|
117 |
+
}
|
118 |
+
function Je(s, e, t, r, i) {
|
119 |
+
for (let n = 0; n < i; n++)
|
120 |
+
t[r + n] = s[n] ^ e[n & 3];
|
121 |
+
}
|
122 |
+
function et(s, e) {
|
123 |
+
for (let t = 0; t < s.length; t++)
|
124 |
+
s[t] ^= e[t & 3];
|
125 |
+
}
|
126 |
+
function Ot(s) {
|
127 |
+
return s.length === s.buffer.byteLength ? s.buffer : s.buffer.slice(s.byteOffset, s.byteOffset + s.length);
|
128 |
+
}
|
129 |
+
function Ee(s) {
|
130 |
+
if (Ee.readOnly = !0, Buffer.isBuffer(s))
|
131 |
+
return s;
|
132 |
+
let e;
|
133 |
+
return s instanceof ArrayBuffer ? e = new Se(s) : ArrayBuffer.isView(s) ? e = new Se(s.buffer, s.byteOffset, s.byteLength) : (e = Buffer.from(s), Ee.readOnly = !1), e;
|
134 |
+
}
|
135 |
+
te.exports = {
|
136 |
+
concat: wt,
|
137 |
+
mask: Je,
|
138 |
+
toArrayBuffer: Ot,
|
139 |
+
toBuffer: Ee,
|
140 |
+
unmask: et
|
141 |
+
};
|
142 |
+
if (!process.env.WS_NO_BUFFER_UTIL)
|
143 |
+
try {
|
144 |
+
const s = require("bufferutil");
|
145 |
+
xt = te.exports.mask = function(e, t, r, i, n) {
|
146 |
+
n < 48 ? Je(e, t, r, i, n) : s.mask(e, t, r, i, n);
|
147 |
+
}, bt = te.exports.unmask = function(e, t) {
|
148 |
+
e.length < 32 ? et(e, t) : s.unmask(e, t);
|
149 |
+
};
|
150 |
+
} catch {
|
151 |
+
}
|
152 |
+
var ne = te.exports;
|
153 |
+
const Ce = Symbol("kDone"), ue = Symbol("kRun");
|
154 |
+
let Ct = class {
|
155 |
+
/**
|
156 |
+
* Creates a new `Limiter`.
|
157 |
+
*
|
158 |
+
* @param {Number} [concurrency=Infinity] The maximum number of jobs allowed
|
159 |
+
* to run concurrently
|
160 |
+
*/
|
161 |
+
constructor(e) {
|
162 |
+
this[Ce] = () => {
|
163 |
+
this.pending--, this[ue]();
|
164 |
+
}, this.concurrency = e || 1 / 0, this.jobs = [], this.pending = 0;
|
165 |
+
}
|
166 |
+
/**
|
167 |
+
* Adds a job to the queue.
|
168 |
+
*
|
169 |
+
* @param {Function} job The job to run
|
170 |
+
* @public
|
171 |
+
*/
|
172 |
+
add(e) {
|
173 |
+
this.jobs.push(e), this[ue]();
|
174 |
+
}
|
175 |
+
/**
|
176 |
+
* Removes a job from the queue and runs it if possible.
|
177 |
+
*
|
178 |
+
* @private
|
179 |
+
*/
|
180 |
+
[ue]() {
|
181 |
+
if (this.pending !== this.concurrency && this.jobs.length) {
|
182 |
+
const e = this.jobs.shift();
|
183 |
+
this.pending++, e(this[Ce]);
|
184 |
+
}
|
185 |
+
}
|
186 |
+
};
|
187 |
+
var Tt = Ct;
|
188 |
+
const W = S, Te = ne, Lt = Tt, { kStatusCode: tt } = U, Nt = Buffer[Symbol.species], Pt = Buffer.from([0, 0, 255, 255]), se = Symbol("permessage-deflate"), w = Symbol("total-length"), V = Symbol("callback"), C = Symbol("buffers"), J = Symbol("error");
|
189 |
+
let K, Rt = class {
|
190 |
+
/**
|
191 |
+
* Creates a PerMessageDeflate instance.
|
192 |
+
*
|
193 |
+
* @param {Object} [options] Configuration options
|
194 |
+
* @param {(Boolean|Number)} [options.clientMaxWindowBits] Advertise support
|
195 |
+
* for, or request, a custom client window size
|
196 |
+
* @param {Boolean} [options.clientNoContextTakeover=false] Advertise/
|
197 |
+
* acknowledge disabling of client context takeover
|
198 |
+
* @param {Number} [options.concurrencyLimit=10] The number of concurrent
|
199 |
+
* calls to zlib
|
200 |
+
* @param {(Boolean|Number)} [options.serverMaxWindowBits] Request/confirm the
|
201 |
+
* use of a custom server window size
|
202 |
+
* @param {Boolean} [options.serverNoContextTakeover=false] Request/accept
|
203 |
+
* disabling of server context takeover
|
204 |
+
* @param {Number} [options.threshold=1024] Size (in bytes) below which
|
205 |
+
* messages should not be compressed if context takeover is disabled
|
206 |
+
* @param {Object} [options.zlibDeflateOptions] Options to pass to zlib on
|
207 |
+
* deflate
|
208 |
+
* @param {Object} [options.zlibInflateOptions] Options to pass to zlib on
|
209 |
+
* inflate
|
210 |
+
* @param {Boolean} [isServer=false] Create the instance in either server or
|
211 |
+
* client mode
|
212 |
+
* @param {Number} [maxPayload=0] The maximum allowed message length
|
213 |
+
*/
|
214 |
+
constructor(e, t, r) {
|
215 |
+
if (this._maxPayload = r | 0, this._options = e || {}, this._threshold = this._options.threshold !== void 0 ? this._options.threshold : 1024, this._isServer = !!t, this._deflate = null, this._inflate = null, this.params = null, !K) {
|
216 |
+
const i = this._options.concurrencyLimit !== void 0 ? this._options.concurrencyLimit : 10;
|
217 |
+
K = new Lt(i);
|
218 |
+
}
|
219 |
+
}
|
220 |
+
/**
|
221 |
+
* @type {String}
|
222 |
+
*/
|
223 |
+
static get extensionName() {
|
224 |
+
return "permessage-deflate";
|
225 |
+
}
|
226 |
+
/**
|
227 |
+
* Create an extension negotiation offer.
|
228 |
+
*
|
229 |
+
* @return {Object} Extension parameters
|
230 |
+
* @public
|
231 |
+
*/
|
232 |
+
offer() {
|
233 |
+
const e = {};
|
234 |
+
return this._options.serverNoContextTakeover && (e.server_no_context_takeover = !0), this._options.clientNoContextTakeover && (e.client_no_context_takeover = !0), this._options.serverMaxWindowBits && (e.server_max_window_bits = this._options.serverMaxWindowBits), this._options.clientMaxWindowBits ? e.client_max_window_bits = this._options.clientMaxWindowBits : this._options.clientMaxWindowBits == null && (e.client_max_window_bits = !0), e;
|
235 |
+
}
|
236 |
+
/**
|
237 |
+
* Accept an extension negotiation offer/response.
|
238 |
+
*
|
239 |
+
* @param {Array} configurations The extension negotiation offers/reponse
|
240 |
+
* @return {Object} Accepted configuration
|
241 |
+
* @public
|
242 |
+
*/
|
243 |
+
accept(e) {
|
244 |
+
return e = this.normalizeParams(e), this.params = this._isServer ? this.acceptAsServer(e) : this.acceptAsClient(e), this.params;
|
245 |
+
}
|
246 |
+
/**
|
247 |
+
* Releases all resources used by the extension.
|
248 |
+
*
|
249 |
+
* @public
|
250 |
+
*/
|
251 |
+
cleanup() {
|
252 |
+
if (this._inflate && (this._inflate.close(), this._inflate = null), this._deflate) {
|
253 |
+
const e = this._deflate[V];
|
254 |
+
this._deflate.close(), this._deflate = null, e && e(
|
255 |
+
new Error(
|
256 |
+
"The deflate stream was closed while data was being processed"
|
257 |
+
)
|
258 |
+
);
|
259 |
+
}
|
260 |
+
}
|
261 |
+
/**
|
262 |
+
* Accept an extension negotiation offer.
|
263 |
+
*
|
264 |
+
* @param {Array} offers The extension negotiation offers
|
265 |
+
* @return {Object} Accepted configuration
|
266 |
+
* @private
|
267 |
+
*/
|
268 |
+
acceptAsServer(e) {
|
269 |
+
const t = this._options, r = e.find((i) => !(t.serverNoContextTakeover === !1 && i.server_no_context_takeover || i.server_max_window_bits && (t.serverMaxWindowBits === !1 || typeof t.serverMaxWindowBits == "number" && t.serverMaxWindowBits > i.server_max_window_bits) || typeof t.clientMaxWindowBits == "number" && !i.client_max_window_bits));
|
270 |
+
if (!r)
|
271 |
+
throw new Error("None of the extension offers can be accepted");
|
272 |
+
return t.serverNoContextTakeover && (r.server_no_context_takeover = !0), t.clientNoContextTakeover && (r.client_no_context_takeover = !0), typeof t.serverMaxWindowBits == "number" && (r.server_max_window_bits = t.serverMaxWindowBits), typeof t.clientMaxWindowBits == "number" ? r.client_max_window_bits = t.clientMaxWindowBits : (r.client_max_window_bits === !0 || t.clientMaxWindowBits === !1) && delete r.client_max_window_bits, r;
|
273 |
+
}
|
274 |
+
/**
|
275 |
+
* Accept the extension negotiation response.
|
276 |
+
*
|
277 |
+
* @param {Array} response The extension negotiation response
|
278 |
+
* @return {Object} Accepted configuration
|
279 |
+
* @private
|
280 |
+
*/
|
281 |
+
acceptAsClient(e) {
|
282 |
+
const t = e[0];
|
283 |
+
if (this._options.clientNoContextTakeover === !1 && t.client_no_context_takeover)
|
284 |
+
throw new Error('Unexpected parameter "client_no_context_takeover"');
|
285 |
+
if (!t.client_max_window_bits)
|
286 |
+
typeof this._options.clientMaxWindowBits == "number" && (t.client_max_window_bits = this._options.clientMaxWindowBits);
|
287 |
+
else if (this._options.clientMaxWindowBits === !1 || typeof this._options.clientMaxWindowBits == "number" && t.client_max_window_bits > this._options.clientMaxWindowBits)
|
288 |
+
throw new Error(
|
289 |
+
'Unexpected or invalid parameter "client_max_window_bits"'
|
290 |
+
);
|
291 |
+
return t;
|
292 |
+
}
|
293 |
+
/**
|
294 |
+
* Normalize parameters.
|
295 |
+
*
|
296 |
+
* @param {Array} configurations The extension negotiation offers/reponse
|
297 |
+
* @return {Array} The offers/response with normalized parameters
|
298 |
+
* @private
|
299 |
+
*/
|
300 |
+
normalizeParams(e) {
|
301 |
+
return e.forEach((t) => {
|
302 |
+
Object.keys(t).forEach((r) => {
|
303 |
+
let i = t[r];
|
304 |
+
if (i.length > 1)
|
305 |
+
throw new Error(`Parameter "${r}" must have only a single value`);
|
306 |
+
if (i = i[0], r === "client_max_window_bits") {
|
307 |
+
if (i !== !0) {
|
308 |
+
const n = +i;
|
309 |
+
if (!Number.isInteger(n) || n < 8 || n > 15)
|
310 |
+
throw new TypeError(
|
311 |
+
`Invalid value for parameter "${r}": ${i}`
|
312 |
+
);
|
313 |
+
i = n;
|
314 |
+
} else if (!this._isServer)
|
315 |
+
throw new TypeError(
|
316 |
+
`Invalid value for parameter "${r}": ${i}`
|
317 |
+
);
|
318 |
+
} else if (r === "server_max_window_bits") {
|
319 |
+
const n = +i;
|
320 |
+
if (!Number.isInteger(n) || n < 8 || n > 15)
|
321 |
+
throw new TypeError(
|
322 |
+
`Invalid value for parameter "${r}": ${i}`
|
323 |
+
);
|
324 |
+
i = n;
|
325 |
+
} else if (r === "client_no_context_takeover" || r === "server_no_context_takeover") {
|
326 |
+
if (i !== !0)
|
327 |
+
throw new TypeError(
|
328 |
+
`Invalid value for parameter "${r}": ${i}`
|
329 |
+
);
|
330 |
+
} else
|
331 |
+
throw new Error(`Unknown parameter "${r}"`);
|
332 |
+
t[r] = i;
|
333 |
+
});
|
334 |
+
}), e;
|
335 |
+
}
|
336 |
+
/**
|
337 |
+
* Decompress data. Concurrency limited.
|
338 |
+
*
|
339 |
+
* @param {Buffer} data Compressed data
|
340 |
+
* @param {Boolean} fin Specifies whether or not this is the last fragment
|
341 |
+
* @param {Function} callback Callback
|
342 |
+
* @public
|
343 |
+
*/
|
344 |
+
decompress(e, t, r) {
|
345 |
+
K.add((i) => {
|
346 |
+
this._decompress(e, t, (n, o) => {
|
347 |
+
i(), r(n, o);
|
348 |
+
});
|
349 |
+
});
|
350 |
+
}
|
351 |
+
/**
|
352 |
+
* Compress data. Concurrency limited.
|
353 |
+
*
|
354 |
+
* @param {(Buffer|String)} data Data to compress
|
355 |
+
* @param {Boolean} fin Specifies whether or not this is the last fragment
|
356 |
+
* @param {Function} callback Callback
|
357 |
+
* @public
|
358 |
+
*/
|
359 |
+
compress(e, t, r) {
|
360 |
+
K.add((i) => {
|
361 |
+
this._compress(e, t, (n, o) => {
|
362 |
+
i(), r(n, o);
|
363 |
+
});
|
364 |
+
});
|
365 |
+
}
|
366 |
+
/**
|
367 |
+
* Decompress data.
|
368 |
+
*
|
369 |
+
* @param {Buffer} data Compressed data
|
370 |
+
* @param {Boolean} fin Specifies whether or not this is the last fragment
|
371 |
+
* @param {Function} callback Callback
|
372 |
+
* @private
|
373 |
+
*/
|
374 |
+
_decompress(e, t, r) {
|
375 |
+
const i = this._isServer ? "client" : "server";
|
376 |
+
if (!this._inflate) {
|
377 |
+
const n = `${i}_max_window_bits`, o = typeof this.params[n] != "number" ? W.Z_DEFAULT_WINDOWBITS : this.params[n];
|
378 |
+
this._inflate = W.createInflateRaw({
|
379 |
+
...this._options.zlibInflateOptions,
|
380 |
+
windowBits: o
|
381 |
+
}), this._inflate[se] = this, this._inflate[w] = 0, this._inflate[C] = [], this._inflate.on("error", Bt), this._inflate.on("data", st);
|
382 |
+
}
|
383 |
+
this._inflate[V] = r, this._inflate.write(e), t && this._inflate.write(Pt), this._inflate.flush(() => {
|
384 |
+
const n = this._inflate[J];
|
385 |
+
if (n) {
|
386 |
+
this._inflate.close(), this._inflate = null, r(n);
|
387 |
+
return;
|
388 |
+
}
|
389 |
+
const o = Te.concat(
|
390 |
+
this._inflate[C],
|
391 |
+
this._inflate[w]
|
392 |
+
);
|
393 |
+
this._inflate._readableState.endEmitted ? (this._inflate.close(), this._inflate = null) : (this._inflate[w] = 0, this._inflate[C] = [], t && this.params[`${i}_no_context_takeover`] && this._inflate.reset()), r(null, o);
|
394 |
+
});
|
395 |
+
}
|
396 |
+
/**
|
397 |
+
* Compress data.
|
398 |
+
*
|
399 |
+
* @param {(Buffer|String)} data Data to compress
|
400 |
+
* @param {Boolean} fin Specifies whether or not this is the last fragment
|
401 |
+
* @param {Function} callback Callback
|
402 |
+
* @private
|
403 |
+
*/
|
404 |
+
_compress(e, t, r) {
|
405 |
+
const i = this._isServer ? "server" : "client";
|
406 |
+
if (!this._deflate) {
|
407 |
+
const n = `${i}_max_window_bits`, o = typeof this.params[n] != "number" ? W.Z_DEFAULT_WINDOWBITS : this.params[n];
|
408 |
+
this._deflate = W.createDeflateRaw({
|
409 |
+
...this._options.zlibDeflateOptions,
|
410 |
+
windowBits: o
|
411 |
+
}), this._deflate[w] = 0, this._deflate[C] = [], this._deflate.on("data", Ut);
|
412 |
+
}
|
413 |
+
this._deflate[V] = r, this._deflate.write(e), this._deflate.flush(W.Z_SYNC_FLUSH, () => {
|
414 |
+
if (!this._deflate)
|
415 |
+
return;
|
416 |
+
let n = Te.concat(
|
417 |
+
this._deflate[C],
|
418 |
+
this._deflate[w]
|
419 |
+
);
|
420 |
+
t && (n = new Nt(n.buffer, n.byteOffset, n.length - 4)), this._deflate[V] = null, this._deflate[w] = 0, this._deflate[C] = [], t && this.params[`${i}_no_context_takeover`] && this._deflate.reset(), r(null, n);
|
421 |
+
});
|
422 |
+
}
|
423 |
+
};
|
424 |
+
var oe = Rt;
|
425 |
+
function Ut(s) {
|
426 |
+
this[C].push(s), this[w] += s.length;
|
427 |
+
}
|
428 |
+
function st(s) {
|
429 |
+
if (this[w] += s.length, this[se]._maxPayload < 1 || this[w] <= this[se]._maxPayload) {
|
430 |
+
this[C].push(s);
|
431 |
+
return;
|
432 |
+
}
|
433 |
+
this[J] = new RangeError("Max payload size exceeded"), this[J].code = "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH", this[J][tt] = 1009, this.removeListener("data", st), this.reset();
|
434 |
+
}
|
435 |
+
function Bt(s) {
|
436 |
+
this[se]._inflate = null, s[tt] = 1007, this[V](s);
|
437 |
+
}
|
438 |
+
var re = { exports: {} };
|
439 |
+
const $t = {}, Mt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
440 |
+
__proto__: null,
|
441 |
+
default: $t
|
442 |
+
}, Symbol.toStringTag, { value: "Module" })), It = /* @__PURE__ */ gt(Mt);
|
443 |
+
var Le;
|
444 |
+
const { isUtf8: Ne } = S, Dt = [
|
445 |
+
0,
|
446 |
+
0,
|
447 |
+
0,
|
448 |
+
0,
|
449 |
+
0,
|
450 |
+
0,
|
451 |
+
0,
|
452 |
+
0,
|
453 |
+
0,
|
454 |
+
0,
|
455 |
+
0,
|
456 |
+
0,
|
457 |
+
0,
|
458 |
+
0,
|
459 |
+
0,
|
460 |
+
0,
|
461 |
+
// 0 - 15
|
462 |
+
0,
|
463 |
+
0,
|
464 |
+
0,
|
465 |
+
0,
|
466 |
+
0,
|
467 |
+
0,
|
468 |
+
0,
|
469 |
+
0,
|
470 |
+
0,
|
471 |
+
0,
|
472 |
+
0,
|
473 |
+
0,
|
474 |
+
0,
|
475 |
+
0,
|
476 |
+
0,
|
477 |
+
0,
|
478 |
+
// 16 - 31
|
479 |
+
0,
|
480 |
+
1,
|
481 |
+
0,
|
482 |
+
1,
|
483 |
+
1,
|
484 |
+
1,
|
485 |
+
1,
|
486 |
+
1,
|
487 |
+
0,
|
488 |
+
0,
|
489 |
+
1,
|
490 |
+
1,
|
491 |
+
0,
|
492 |
+
1,
|
493 |
+
1,
|
494 |
+
0,
|
495 |
+
// 32 - 47
|
496 |
+
1,
|
497 |
+
1,
|
498 |
+
1,
|
499 |
+
1,
|
500 |
+
1,
|
501 |
+
1,
|
502 |
+
1,
|
503 |
+
1,
|
504 |
+
1,
|
505 |
+
1,
|
506 |
+
0,
|
507 |
+
0,
|
508 |
+
0,
|
509 |
+
0,
|
510 |
+
0,
|
511 |
+
0,
|
512 |
+
// 48 - 63
|
513 |
+
0,
|
514 |
+
1,
|
515 |
+
1,
|
516 |
+
1,
|
517 |
+
1,
|
518 |
+
1,
|
519 |
+
1,
|
520 |
+
1,
|
521 |
+
1,
|
522 |
+
1,
|
523 |
+
1,
|
524 |
+
1,
|
525 |
+
1,
|
526 |
+
1,
|
527 |
+
1,
|
528 |
+
1,
|
529 |
+
// 64 - 79
|
530 |
+
1,
|
531 |
+
1,
|
532 |
+
1,
|
533 |
+
1,
|
534 |
+
1,
|
535 |
+
1,
|
536 |
+
1,
|
537 |
+
1,
|
538 |
+
1,
|
539 |
+
1,
|
540 |
+
1,
|
541 |
+
0,
|
542 |
+
0,
|
543 |
+
0,
|
544 |
+
1,
|
545 |
+
1,
|
546 |
+
// 80 - 95
|
547 |
+
1,
|
548 |
+
1,
|
549 |
+
1,
|
550 |
+
1,
|
551 |
+
1,
|
552 |
+
1,
|
553 |
+
1,
|
554 |
+
1,
|
555 |
+
1,
|
556 |
+
1,
|
557 |
+
1,
|
558 |
+
1,
|
559 |
+
1,
|
560 |
+
1,
|
561 |
+
1,
|
562 |
+
1,
|
563 |
+
// 96 - 111
|
564 |
+
1,
|
565 |
+
1,
|
566 |
+
1,
|
567 |
+
1,
|
568 |
+
1,
|
569 |
+
1,
|
570 |
+
1,
|
571 |
+
1,
|
572 |
+
1,
|
573 |
+
1,
|
574 |
+
1,
|
575 |
+
0,
|
576 |
+
1,
|
577 |
+
0,
|
578 |
+
1,
|
579 |
+
0
|
580 |
+
// 112 - 127
|
581 |
+
];
|
582 |
+
function Wt(s) {
|
583 |
+
return s >= 1e3 && s <= 1014 && s !== 1004 && s !== 1005 && s !== 1006 || s >= 3e3 && s <= 4999;
|
584 |
+
}
|
585 |
+
function be(s) {
|
586 |
+
const e = s.length;
|
587 |
+
let t = 0;
|
588 |
+
for (; t < e; )
|
589 |
+
if (!(s[t] & 128))
|
590 |
+
t++;
|
591 |
+
else if ((s[t] & 224) === 192) {
|
592 |
+
if (t + 1 === e || (s[t + 1] & 192) !== 128 || (s[t] & 254) === 192)
|
593 |
+
return !1;
|
594 |
+
t += 2;
|
595 |
+
} else if ((s[t] & 240) === 224) {
|
596 |
+
if (t + 2 >= e || (s[t + 1] & 192) !== 128 || (s[t + 2] & 192) !== 128 || s[t] === 224 && (s[t + 1] & 224) === 128 || // Overlong
|
597 |
+
s[t] === 237 && (s[t + 1] & 224) === 160)
|
598 |
+
return !1;
|
599 |
+
t += 3;
|
600 |
+
} else if ((s[t] & 248) === 240) {
|
601 |
+
if (t + 3 >= e || (s[t + 1] & 192) !== 128 || (s[t + 2] & 192) !== 128 || (s[t + 3] & 192) !== 128 || s[t] === 240 && (s[t + 1] & 240) === 128 || // Overlong
|
602 |
+
s[t] === 244 && s[t + 1] > 143 || s[t] > 244)
|
603 |
+
return !1;
|
604 |
+
t += 4;
|
605 |
+
} else
|
606 |
+
return !1;
|
607 |
+
return !0;
|
608 |
+
}
|
609 |
+
re.exports = {
|
610 |
+
isValidStatusCode: Wt,
|
611 |
+
isValidUTF8: be,
|
612 |
+
tokenChars: Dt
|
613 |
+
};
|
614 |
+
if (Ne)
|
615 |
+
Le = re.exports.isValidUTF8 = function(s) {
|
616 |
+
return s.length < 24 ? be(s) : Ne(s);
|
617 |
+
};
|
618 |
+
else if (!process.env.WS_NO_UTF_8_VALIDATE)
|
619 |
+
try {
|
620 |
+
const s = It;
|
621 |
+
Le = re.exports.isValidUTF8 = function(e) {
|
622 |
+
return e.length < 32 ? be(e) : s(e);
|
623 |
+
};
|
624 |
+
} catch {
|
625 |
+
}
|
626 |
+
var ae = re.exports;
|
627 |
+
const { Writable: At } = S, Pe = oe, {
|
628 |
+
BINARY_TYPES: Ft,
|
629 |
+
EMPTY_BUFFER: Re,
|
630 |
+
kStatusCode: jt,
|
631 |
+
kWebSocket: Gt
|
632 |
+
} = U, { concat: de, toArrayBuffer: Vt, unmask: Ht } = ne, { isValidStatusCode: zt, isValidUTF8: Ue } = ae, X = Buffer[Symbol.species], A = 0, Be = 1, $e = 2, Me = 3, _e = 4, Yt = 5;
|
633 |
+
let qt = class extends At {
|
634 |
+
/**
|
635 |
+
* Creates a Receiver instance.
|
636 |
+
*
|
637 |
+
* @param {Object} [options] Options object
|
638 |
+
* @param {String} [options.binaryType=nodebuffer] The type for binary data
|
639 |
+
* @param {Object} [options.extensions] An object containing the negotiated
|
640 |
+
* extensions
|
641 |
+
* @param {Boolean} [options.isServer=false] Specifies whether to operate in
|
642 |
+
* client or server mode
|
643 |
+
* @param {Number} [options.maxPayload=0] The maximum allowed message length
|
644 |
+
* @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
|
645 |
+
* not to skip UTF-8 validation for text and close messages
|
646 |
+
*/
|
647 |
+
constructor(e = {}) {
|
648 |
+
super(), this._binaryType = e.binaryType || Ft[0], this._extensions = e.extensions || {}, this._isServer = !!e.isServer, this._maxPayload = e.maxPayload | 0, this._skipUTF8Validation = !!e.skipUTF8Validation, this[Gt] = void 0, this._bufferedBytes = 0, this._buffers = [], this._compressed = !1, this._payloadLength = 0, this._mask = void 0, this._fragmented = 0, this._masked = !1, this._fin = !1, this._opcode = 0, this._totalPayloadLength = 0, this._messageLength = 0, this._fragments = [], this._state = A, this._loop = !1;
|
649 |
+
}
|
650 |
+
/**
|
651 |
+
* Implements `Writable.prototype._write()`.
|
652 |
+
*
|
653 |
+
* @param {Buffer} chunk The chunk of data to write
|
654 |
+
* @param {String} encoding The character encoding of `chunk`
|
655 |
+
* @param {Function} cb Callback
|
656 |
+
* @private
|
657 |
+
*/
|
658 |
+
_write(e, t, r) {
|
659 |
+
if (this._opcode === 8 && this._state == A)
|
660 |
+
return r();
|
661 |
+
this._bufferedBytes += e.length, this._buffers.push(e), this.startLoop(r);
|
662 |
+
}
|
663 |
+
/**
|
664 |
+
* Consumes `n` bytes from the buffered data.
|
665 |
+
*
|
666 |
+
* @param {Number} n The number of bytes to consume
|
667 |
+
* @return {Buffer} The consumed bytes
|
668 |
+
* @private
|
669 |
+
*/
|
670 |
+
consume(e) {
|
671 |
+
if (this._bufferedBytes -= e, e === this._buffers[0].length)
|
672 |
+
return this._buffers.shift();
|
673 |
+
if (e < this._buffers[0].length) {
|
674 |
+
const r = this._buffers[0];
|
675 |
+
return this._buffers[0] = new X(
|
676 |
+
r.buffer,
|
677 |
+
r.byteOffset + e,
|
678 |
+
r.length - e
|
679 |
+
), new X(r.buffer, r.byteOffset, e);
|
680 |
+
}
|
681 |
+
const t = Buffer.allocUnsafe(e);
|
682 |
+
do {
|
683 |
+
const r = this._buffers[0], i = t.length - e;
|
684 |
+
e >= r.length ? t.set(this._buffers.shift(), i) : (t.set(new Uint8Array(r.buffer, r.byteOffset, e), i), this._buffers[0] = new X(
|
685 |
+
r.buffer,
|
686 |
+
r.byteOffset + e,
|
687 |
+
r.length - e
|
688 |
+
)), e -= r.length;
|
689 |
+
} while (e > 0);
|
690 |
+
return t;
|
691 |
+
}
|
692 |
+
/**
|
693 |
+
* Starts the parsing loop.
|
694 |
+
*
|
695 |
+
* @param {Function} cb Callback
|
696 |
+
* @private
|
697 |
+
*/
|
698 |
+
startLoop(e) {
|
699 |
+
let t;
|
700 |
+
this._loop = !0;
|
701 |
+
do
|
702 |
+
switch (this._state) {
|
703 |
+
case A:
|
704 |
+
t = this.getInfo();
|
705 |
+
break;
|
706 |
+
case Be:
|
707 |
+
t = this.getPayloadLength16();
|
708 |
+
break;
|
709 |
+
case $e:
|
710 |
+
t = this.getPayloadLength64();
|
711 |
+
break;
|
712 |
+
case Me:
|
713 |
+
this.getMask();
|
714 |
+
break;
|
715 |
+
case _e:
|
716 |
+
t = this.getData(e);
|
717 |
+
break;
|
718 |
+
default:
|
719 |
+
this._loop = !1;
|
720 |
+
return;
|
721 |
+
}
|
722 |
+
while (this._loop);
|
723 |
+
e(t);
|
724 |
+
}
|
725 |
+
/**
|
726 |
+
* Reads the first two bytes of a frame.
|
727 |
+
*
|
728 |
+
* @return {(RangeError|undefined)} A possible error
|
729 |
+
* @private
|
730 |
+
*/
|
731 |
+
getInfo() {
|
732 |
+
if (this._bufferedBytes < 2) {
|
733 |
+
this._loop = !1;
|
734 |
+
return;
|
735 |
+
}
|
736 |
+
const e = this.consume(2);
|
737 |
+
if (e[0] & 48)
|
738 |
+
return this._loop = !1, g(
|
739 |
+
RangeError,
|
740 |
+
"RSV2 and RSV3 must be clear",
|
741 |
+
!0,
|
742 |
+
1002,
|
743 |
+
"WS_ERR_UNEXPECTED_RSV_2_3"
|
744 |
+
);
|
745 |
+
const t = (e[0] & 64) === 64;
|
746 |
+
if (t && !this._extensions[Pe.extensionName])
|
747 |
+
return this._loop = !1, g(
|
748 |
+
RangeError,
|
749 |
+
"RSV1 must be clear",
|
750 |
+
!0,
|
751 |
+
1002,
|
752 |
+
"WS_ERR_UNEXPECTED_RSV_1"
|
753 |
+
);
|
754 |
+
if (this._fin = (e[0] & 128) === 128, this._opcode = e[0] & 15, this._payloadLength = e[1] & 127, this._opcode === 0) {
|
755 |
+
if (t)
|
756 |
+
return this._loop = !1, g(
|
757 |
+
RangeError,
|
758 |
+
"RSV1 must be clear",
|
759 |
+
!0,
|
760 |
+
1002,
|
761 |
+
"WS_ERR_UNEXPECTED_RSV_1"
|
762 |
+
);
|
763 |
+
if (!this._fragmented)
|
764 |
+
return this._loop = !1, g(
|
765 |
+
RangeError,
|
766 |
+
"invalid opcode 0",
|
767 |
+
!0,
|
768 |
+
1002,
|
769 |
+
"WS_ERR_INVALID_OPCODE"
|
770 |
+
);
|
771 |
+
this._opcode = this._fragmented;
|
772 |
+
} else if (this._opcode === 1 || this._opcode === 2) {
|
773 |
+
if (this._fragmented)
|
774 |
+
return this._loop = !1, g(
|
775 |
+
RangeError,
|
776 |
+
`invalid opcode ${this._opcode}`,
|
777 |
+
!0,
|
778 |
+
1002,
|
779 |
+
"WS_ERR_INVALID_OPCODE"
|
780 |
+
);
|
781 |
+
this._compressed = t;
|
782 |
+
} else if (this._opcode > 7 && this._opcode < 11) {
|
783 |
+
if (!this._fin)
|
784 |
+
return this._loop = !1, g(
|
785 |
+
RangeError,
|
786 |
+
"FIN must be set",
|
787 |
+
!0,
|
788 |
+
1002,
|
789 |
+
"WS_ERR_EXPECTED_FIN"
|
790 |
+
);
|
791 |
+
if (t)
|
792 |
+
return this._loop = !1, g(
|
793 |
+
RangeError,
|
794 |
+
"RSV1 must be clear",
|
795 |
+
!0,
|
796 |
+
1002,
|
797 |
+
"WS_ERR_UNEXPECTED_RSV_1"
|
798 |
+
);
|
799 |
+
if (this._payloadLength > 125 || this._opcode === 8 && this._payloadLength === 1)
|
800 |
+
return this._loop = !1, g(
|
801 |
+
RangeError,
|
802 |
+
`invalid payload length ${this._payloadLength}`,
|
803 |
+
!0,
|
804 |
+
1002,
|
805 |
+
"WS_ERR_INVALID_CONTROL_PAYLOAD_LENGTH"
|
806 |
+
);
|
807 |
+
} else
|
808 |
+
return this._loop = !1, g(
|
809 |
+
RangeError,
|
810 |
+
`invalid opcode ${this._opcode}`,
|
811 |
+
!0,
|
812 |
+
1002,
|
813 |
+
"WS_ERR_INVALID_OPCODE"
|
814 |
+
);
|
815 |
+
if (!this._fin && !this._fragmented && (this._fragmented = this._opcode), this._masked = (e[1] & 128) === 128, this._isServer) {
|
816 |
+
if (!this._masked)
|
817 |
+
return this._loop = !1, g(
|
818 |
+
RangeError,
|
819 |
+
"MASK must be set",
|
820 |
+
!0,
|
821 |
+
1002,
|
822 |
+
"WS_ERR_EXPECTED_MASK"
|
823 |
+
);
|
824 |
+
} else if (this._masked)
|
825 |
+
return this._loop = !1, g(
|
826 |
+
RangeError,
|
827 |
+
"MASK must be clear",
|
828 |
+
!0,
|
829 |
+
1002,
|
830 |
+
"WS_ERR_UNEXPECTED_MASK"
|
831 |
+
);
|
832 |
+
if (this._payloadLength === 126)
|
833 |
+
this._state = Be;
|
834 |
+
else if (this._payloadLength === 127)
|
835 |
+
this._state = $e;
|
836 |
+
else
|
837 |
+
return this.haveLength();
|
838 |
+
}
|
839 |
+
/**
|
840 |
+
* Gets extended payload length (7+16).
|
841 |
+
*
|
842 |
+
* @return {(RangeError|undefined)} A possible error
|
843 |
+
* @private
|
844 |
+
*/
|
845 |
+
getPayloadLength16() {
|
846 |
+
if (this._bufferedBytes < 2) {
|
847 |
+
this._loop = !1;
|
848 |
+
return;
|
849 |
+
}
|
850 |
+
return this._payloadLength = this.consume(2).readUInt16BE(0), this.haveLength();
|
851 |
+
}
|
852 |
+
/**
|
853 |
+
* Gets extended payload length (7+64).
|
854 |
+
*
|
855 |
+
* @return {(RangeError|undefined)} A possible error
|
856 |
+
* @private
|
857 |
+
*/
|
858 |
+
getPayloadLength64() {
|
859 |
+
if (this._bufferedBytes < 8) {
|
860 |
+
this._loop = !1;
|
861 |
+
return;
|
862 |
+
}
|
863 |
+
const e = this.consume(8), t = e.readUInt32BE(0);
|
864 |
+
return t > Math.pow(2, 21) - 1 ? (this._loop = !1, g(
|
865 |
+
RangeError,
|
866 |
+
"Unsupported WebSocket frame: payload length > 2^53 - 1",
|
867 |
+
!1,
|
868 |
+
1009,
|
869 |
+
"WS_ERR_UNSUPPORTED_DATA_PAYLOAD_LENGTH"
|
870 |
+
)) : (this._payloadLength = t * Math.pow(2, 32) + e.readUInt32BE(4), this.haveLength());
|
871 |
+
}
|
872 |
+
/**
|
873 |
+
* Payload length has been read.
|
874 |
+
*
|
875 |
+
* @return {(RangeError|undefined)} A possible error
|
876 |
+
* @private
|
877 |
+
*/
|
878 |
+
haveLength() {
|
879 |
+
if (this._payloadLength && this._opcode < 8 && (this._totalPayloadLength += this._payloadLength, this._totalPayloadLength > this._maxPayload && this._maxPayload > 0))
|
880 |
+
return this._loop = !1, g(
|
881 |
+
RangeError,
|
882 |
+
"Max payload size exceeded",
|
883 |
+
!1,
|
884 |
+
1009,
|
885 |
+
"WS_ERR_UNSUPPORTED_MESSAGE_LENGTH"
|
886 |
+
);
|
887 |
+
this._masked ? this._state = Me : this._state = _e;
|
888 |
+
}
|
889 |
+
/**
|
890 |
+
* Reads mask bytes.
|
891 |
+
*
|
892 |
+
* @private
|
893 |
+
*/
|
894 |
+
getMask() {
|
895 |
+
if (this._bufferedBytes < 4) {
|
896 |
+
this._loop = !1;
|
897 |
+
return;
|
898 |
+
}
|
899 |
+
this._mask = this.consume(4), this._state = _e;
|
900 |
+
}
|
901 |
+
/**
|
902 |
+
* Reads data bytes.
|
903 |
+
*
|
904 |
+
* @param {Function} cb Callback
|
905 |
+
* @return {(Error|RangeError|undefined)} A possible error
|
906 |
+
* @private
|
907 |
+
*/
|
908 |
+
getData(e) {
|
909 |
+
let t = Re;
|
910 |
+
if (this._payloadLength) {
|
911 |
+
if (this._bufferedBytes < this._payloadLength) {
|
912 |
+
this._loop = !1;
|
913 |
+
return;
|
914 |
+
}
|
915 |
+
t = this.consume(this._payloadLength), this._masked && this._mask[0] | this._mask[1] | this._mask[2] | this._mask[3] && Ht(t, this._mask);
|
916 |
+
}
|
917 |
+
if (this._opcode > 7)
|
918 |
+
return this.controlMessage(t);
|
919 |
+
if (this._compressed) {
|
920 |
+
this._state = Yt, this.decompress(t, e);
|
921 |
+
return;
|
922 |
+
}
|
923 |
+
return t.length && (this._messageLength = this._totalPayloadLength, this._fragments.push(t)), this.dataMessage();
|
924 |
+
}
|
925 |
+
/**
|
926 |
+
* Decompresses data.
|
927 |
+
*
|
928 |
+
* @param {Buffer} data Compressed data
|
929 |
+
* @param {Function} cb Callback
|
930 |
+
* @private
|
931 |
+
*/
|
932 |
+
decompress(e, t) {
|
933 |
+
this._extensions[Pe.extensionName].decompress(e, this._fin, (i, n) => {
|
934 |
+
if (i)
|
935 |
+
return t(i);
|
936 |
+
if (n.length) {
|
937 |
+
if (this._messageLength += n.length, this._messageLength > this._maxPayload && this._maxPayload > 0)
|
938 |
+
return t(
|
939 |
+
g(
|
940 |
+
RangeError,
|
941 |
+
"Max payload size exceeded",
|
942 |
+
!1,
|
943 |
+
1009,
|
944 |
+
"WS_ERR_UNSUPPORTED_MESSAGE_LENGTH"
|
945 |
+
)
|
946 |
+
);
|
947 |
+
this._fragments.push(n);
|
948 |
+
}
|
949 |
+
const o = this.dataMessage();
|
950 |
+
if (o)
|
951 |
+
return t(o);
|
952 |
+
this.startLoop(t);
|
953 |
+
});
|
954 |
+
}
|
955 |
+
/**
|
956 |
+
* Handles a data message.
|
957 |
+
*
|
958 |
+
* @return {(Error|undefined)} A possible error
|
959 |
+
* @private
|
960 |
+
*/
|
961 |
+
dataMessage() {
|
962 |
+
if (this._fin) {
|
963 |
+
const e = this._messageLength, t = this._fragments;
|
964 |
+
if (this._totalPayloadLength = 0, this._messageLength = 0, this._fragmented = 0, this._fragments = [], this._opcode === 2) {
|
965 |
+
let r;
|
966 |
+
this._binaryType === "nodebuffer" ? r = de(t, e) : this._binaryType === "arraybuffer" ? r = Vt(de(t, e)) : r = t, this.emit("message", r, !0);
|
967 |
+
} else {
|
968 |
+
const r = de(t, e);
|
969 |
+
if (!this._skipUTF8Validation && !Ue(r))
|
970 |
+
return this._loop = !1, g(
|
971 |
+
Error,
|
972 |
+
"invalid UTF-8 sequence",
|
973 |
+
!0,
|
974 |
+
1007,
|
975 |
+
"WS_ERR_INVALID_UTF8"
|
976 |
+
);
|
977 |
+
this.emit("message", r, !1);
|
978 |
+
}
|
979 |
+
}
|
980 |
+
this._state = A;
|
981 |
+
}
|
982 |
+
/**
|
983 |
+
* Handles a control message.
|
984 |
+
*
|
985 |
+
* @param {Buffer} data Data to handle
|
986 |
+
* @return {(Error|RangeError|undefined)} A possible error
|
987 |
+
* @private
|
988 |
+
*/
|
989 |
+
controlMessage(e) {
|
990 |
+
if (this._opcode === 8)
|
991 |
+
if (this._loop = !1, e.length === 0)
|
992 |
+
this.emit("conclude", 1005, Re), this.end();
|
993 |
+
else {
|
994 |
+
const t = e.readUInt16BE(0);
|
995 |
+
if (!zt(t))
|
996 |
+
return g(
|
997 |
+
RangeError,
|
998 |
+
`invalid status code ${t}`,
|
999 |
+
!0,
|
1000 |
+
1002,
|
1001 |
+
"WS_ERR_INVALID_CLOSE_CODE"
|
1002 |
+
);
|
1003 |
+
const r = new X(
|
1004 |
+
e.buffer,
|
1005 |
+
e.byteOffset + 2,
|
1006 |
+
e.length - 2
|
1007 |
+
);
|
1008 |
+
if (!this._skipUTF8Validation && !Ue(r))
|
1009 |
+
return g(
|
1010 |
+
Error,
|
1011 |
+
"invalid UTF-8 sequence",
|
1012 |
+
!0,
|
1013 |
+
1007,
|
1014 |
+
"WS_ERR_INVALID_UTF8"
|
1015 |
+
);
|
1016 |
+
this.emit("conclude", t, r), this.end();
|
1017 |
+
}
|
1018 |
+
else
|
1019 |
+
this._opcode === 9 ? this.emit("ping", e) : this.emit("pong", e);
|
1020 |
+
this._state = A;
|
1021 |
+
}
|
1022 |
+
};
|
1023 |
+
var rt = qt;
|
1024 |
+
function g(s, e, t, r, i) {
|
1025 |
+
const n = new s(
|
1026 |
+
t ? `Invalid WebSocket frame: ${e}` : e
|
1027 |
+
);
|
1028 |
+
return Error.captureStackTrace(n, g), n.code = i, n[jt] = r, n;
|
1029 |
+
}
|
1030 |
+
const qs = /* @__PURE__ */ z(rt), { randomFillSync: Kt } = S, Ie = oe, { EMPTY_BUFFER: Xt } = U, { isValidStatusCode: Zt } = ae, { mask: De, toBuffer: M } = ne, x = Symbol("kByteLength"), Qt = Buffer.alloc(4);
|
1031 |
+
let Jt = class P {
|
1032 |
+
/**
|
1033 |
+
* Creates a Sender instance.
|
1034 |
+
*
|
1035 |
+
* @param {(net.Socket|tls.Socket)} socket The connection socket
|
1036 |
+
* @param {Object} [extensions] An object containing the negotiated extensions
|
1037 |
+
* @param {Function} [generateMask] The function used to generate the masking
|
1038 |
+
* key
|
1039 |
+
*/
|
1040 |
+
constructor(e, t, r) {
|
1041 |
+
this._extensions = t || {}, r && (this._generateMask = r, this._maskBuffer = Buffer.alloc(4)), this._socket = e, this._firstFragment = !0, this._compress = !1, this._bufferedBytes = 0, this._deflating = !1, this._queue = [];
|
1042 |
+
}
|
1043 |
+
/**
|
1044 |
+
* Frames a piece of data according to the HyBi WebSocket protocol.
|
1045 |
+
*
|
1046 |
+
* @param {(Buffer|String)} data The data to frame
|
1047 |
+
* @param {Object} options Options object
|
1048 |
+
* @param {Boolean} [options.fin=false] Specifies whether or not to set the
|
1049 |
+
* FIN bit
|
1050 |
+
* @param {Function} [options.generateMask] The function used to generate the
|
1051 |
+
* masking key
|
1052 |
+
* @param {Boolean} [options.mask=false] Specifies whether or not to mask
|
1053 |
+
* `data`
|
1054 |
+
* @param {Buffer} [options.maskBuffer] The buffer used to store the masking
|
1055 |
+
* key
|
1056 |
+
* @param {Number} options.opcode The opcode
|
1057 |
+
* @param {Boolean} [options.readOnly=false] Specifies whether `data` can be
|
1058 |
+
* modified
|
1059 |
+
* @param {Boolean} [options.rsv1=false] Specifies whether or not to set the
|
1060 |
+
* RSV1 bit
|
1061 |
+
* @return {(Buffer|String)[]} The framed data
|
1062 |
+
* @public
|
1063 |
+
*/
|
1064 |
+
static frame(e, t) {
|
1065 |
+
let r, i = !1, n = 2, o = !1;
|
1066 |
+
t.mask && (r = t.maskBuffer || Qt, t.generateMask ? t.generateMask(r) : Kt(r, 0, 4), o = (r[0] | r[1] | r[2] | r[3]) === 0, n = 6);
|
1067 |
+
let l;
|
1068 |
+
typeof e == "string" ? (!t.mask || o) && t[x] !== void 0 ? l = t[x] : (e = Buffer.from(e), l = e.length) : (l = e.length, i = t.mask && t.readOnly && !o);
|
1069 |
+
let f = l;
|
1070 |
+
l >= 65536 ? (n += 8, f = 127) : l > 125 && (n += 2, f = 126);
|
1071 |
+
const a = Buffer.allocUnsafe(i ? l + n : n);
|
1072 |
+
return a[0] = t.fin ? t.opcode | 128 : t.opcode, t.rsv1 && (a[0] |= 64), a[1] = f, f === 126 ? a.writeUInt16BE(l, 2) : f === 127 && (a[2] = a[3] = 0, a.writeUIntBE(l, 4, 6)), t.mask ? (a[1] |= 128, a[n - 4] = r[0], a[n - 3] = r[1], a[n - 2] = r[2], a[n - 1] = r[3], o ? [a, e] : i ? (De(e, r, a, n, l), [a]) : (De(e, r, e, 0, l), [a, e])) : [a, e];
|
1073 |
+
}
|
1074 |
+
/**
|
1075 |
+
* Sends a close message to the other peer.
|
1076 |
+
*
|
1077 |
+
* @param {Number} [code] The status code component of the body
|
1078 |
+
* @param {(String|Buffer)} [data] The message component of the body
|
1079 |
+
* @param {Boolean} [mask=false] Specifies whether or not to mask the message
|
1080 |
+
* @param {Function} [cb] Callback
|
1081 |
+
* @public
|
1082 |
+
*/
|
1083 |
+
close(e, t, r, i) {
|
1084 |
+
let n;
|
1085 |
+
if (e === void 0)
|
1086 |
+
n = Xt;
|
1087 |
+
else {
|
1088 |
+
if (typeof e != "number" || !Zt(e))
|
1089 |
+
throw new TypeError("First argument must be a valid error code number");
|
1090 |
+
if (t === void 0 || !t.length)
|
1091 |
+
n = Buffer.allocUnsafe(2), n.writeUInt16BE(e, 0);
|
1092 |
+
else {
|
1093 |
+
const l = Buffer.byteLength(t);
|
1094 |
+
if (l > 123)
|
1095 |
+
throw new RangeError("The message must not be greater than 123 bytes");
|
1096 |
+
n = Buffer.allocUnsafe(2 + l), n.writeUInt16BE(e, 0), typeof t == "string" ? n.write(t, 2) : n.set(t, 2);
|
1097 |
+
}
|
1098 |
+
}
|
1099 |
+
const o = {
|
1100 |
+
[x]: n.length,
|
1101 |
+
fin: !0,
|
1102 |
+
generateMask: this._generateMask,
|
1103 |
+
mask: r,
|
1104 |
+
maskBuffer: this._maskBuffer,
|
1105 |
+
opcode: 8,
|
1106 |
+
readOnly: !1,
|
1107 |
+
rsv1: !1
|
1108 |
+
};
|
1109 |
+
this._deflating ? this.enqueue([this.dispatch, n, !1, o, i]) : this.sendFrame(P.frame(n, o), i);
|
1110 |
+
}
|
1111 |
+
/**
|
1112 |
+
* Sends a ping message to the other peer.
|
1113 |
+
*
|
1114 |
+
* @param {*} data The message to send
|
1115 |
+
* @param {Boolean} [mask=false] Specifies whether or not to mask `data`
|
1116 |
+
* @param {Function} [cb] Callback
|
1117 |
+
* @public
|
1118 |
+
*/
|
1119 |
+
ping(e, t, r) {
|
1120 |
+
let i, n;
|
1121 |
+
if (typeof e == "string" ? (i = Buffer.byteLength(e), n = !1) : (e = M(e), i = e.length, n = M.readOnly), i > 125)
|
1122 |
+
throw new RangeError("The data size must not be greater than 125 bytes");
|
1123 |
+
const o = {
|
1124 |
+
[x]: i,
|
1125 |
+
fin: !0,
|
1126 |
+
generateMask: this._generateMask,
|
1127 |
+
mask: t,
|
1128 |
+
maskBuffer: this._maskBuffer,
|
1129 |
+
opcode: 9,
|
1130 |
+
readOnly: n,
|
1131 |
+
rsv1: !1
|
1132 |
+
};
|
1133 |
+
this._deflating ? this.enqueue([this.dispatch, e, !1, o, r]) : this.sendFrame(P.frame(e, o), r);
|
1134 |
+
}
|
1135 |
+
/**
|
1136 |
+
* Sends a pong message to the other peer.
|
1137 |
+
*
|
1138 |
+
* @param {*} data The message to send
|
1139 |
+
* @param {Boolean} [mask=false] Specifies whether or not to mask `data`
|
1140 |
+
* @param {Function} [cb] Callback
|
1141 |
+
* @public
|
1142 |
+
*/
|
1143 |
+
pong(e, t, r) {
|
1144 |
+
let i, n;
|
1145 |
+
if (typeof e == "string" ? (i = Buffer.byteLength(e), n = !1) : (e = M(e), i = e.length, n = M.readOnly), i > 125)
|
1146 |
+
throw new RangeError("The data size must not be greater than 125 bytes");
|
1147 |
+
const o = {
|
1148 |
+
[x]: i,
|
1149 |
+
fin: !0,
|
1150 |
+
generateMask: this._generateMask,
|
1151 |
+
mask: t,
|
1152 |
+
maskBuffer: this._maskBuffer,
|
1153 |
+
opcode: 10,
|
1154 |
+
readOnly: n,
|
1155 |
+
rsv1: !1
|
1156 |
+
};
|
1157 |
+
this._deflating ? this.enqueue([this.dispatch, e, !1, o, r]) : this.sendFrame(P.frame(e, o), r);
|
1158 |
+
}
|
1159 |
+
/**
|
1160 |
+
* Sends a data message to the other peer.
|
1161 |
+
*
|
1162 |
+
* @param {*} data The message to send
|
1163 |
+
* @param {Object} options Options object
|
1164 |
+
* @param {Boolean} [options.binary=false] Specifies whether `data` is binary
|
1165 |
+
* or text
|
1166 |
+
* @param {Boolean} [options.compress=false] Specifies whether or not to
|
1167 |
+
* compress `data`
|
1168 |
+
* @param {Boolean} [options.fin=false] Specifies whether the fragment is the
|
1169 |
+
* last one
|
1170 |
+
* @param {Boolean} [options.mask=false] Specifies whether or not to mask
|
1171 |
+
* `data`
|
1172 |
+
* @param {Function} [cb] Callback
|
1173 |
+
* @public
|
1174 |
+
*/
|
1175 |
+
send(e, t, r) {
|
1176 |
+
const i = this._extensions[Ie.extensionName];
|
1177 |
+
let n = t.binary ? 2 : 1, o = t.compress, l, f;
|
1178 |
+
if (typeof e == "string" ? (l = Buffer.byteLength(e), f = !1) : (e = M(e), l = e.length, f = M.readOnly), this._firstFragment ? (this._firstFragment = !1, o && i && i.params[i._isServer ? "server_no_context_takeover" : "client_no_context_takeover"] && (o = l >= i._threshold), this._compress = o) : (o = !1, n = 0), t.fin && (this._firstFragment = !0), i) {
|
1179 |
+
const a = {
|
1180 |
+
[x]: l,
|
1181 |
+
fin: t.fin,
|
1182 |
+
generateMask: this._generateMask,
|
1183 |
+
mask: t.mask,
|
1184 |
+
maskBuffer: this._maskBuffer,
|
1185 |
+
opcode: n,
|
1186 |
+
readOnly: f,
|
1187 |
+
rsv1: o
|
1188 |
+
};
|
1189 |
+
this._deflating ? this.enqueue([this.dispatch, e, this._compress, a, r]) : this.dispatch(e, this._compress, a, r);
|
1190 |
+
} else
|
1191 |
+
this.sendFrame(
|
1192 |
+
P.frame(e, {
|
1193 |
+
[x]: l,
|
1194 |
+
fin: t.fin,
|
1195 |
+
generateMask: this._generateMask,
|
1196 |
+
mask: t.mask,
|
1197 |
+
maskBuffer: this._maskBuffer,
|
1198 |
+
opcode: n,
|
1199 |
+
readOnly: f,
|
1200 |
+
rsv1: !1
|
1201 |
+
}),
|
1202 |
+
r
|
1203 |
+
);
|
1204 |
+
}
|
1205 |
+
/**
|
1206 |
+
* Dispatches a message.
|
1207 |
+
*
|
1208 |
+
* @param {(Buffer|String)} data The message to send
|
1209 |
+
* @param {Boolean} [compress=false] Specifies whether or not to compress
|
1210 |
+
* `data`
|
1211 |
+
* @param {Object} options Options object
|
1212 |
+
* @param {Boolean} [options.fin=false] Specifies whether or not to set the
|
1213 |
+
* FIN bit
|
1214 |
+
* @param {Function} [options.generateMask] The function used to generate the
|
1215 |
+
* masking key
|
1216 |
+
* @param {Boolean} [options.mask=false] Specifies whether or not to mask
|
1217 |
+
* `data`
|
1218 |
+
* @param {Buffer} [options.maskBuffer] The buffer used to store the masking
|
1219 |
+
* key
|
1220 |
+
* @param {Number} options.opcode The opcode
|
1221 |
+
* @param {Boolean} [options.readOnly=false] Specifies whether `data` can be
|
1222 |
+
* modified
|
1223 |
+
* @param {Boolean} [options.rsv1=false] Specifies whether or not to set the
|
1224 |
+
* RSV1 bit
|
1225 |
+
* @param {Function} [cb] Callback
|
1226 |
+
* @private
|
1227 |
+
*/
|
1228 |
+
dispatch(e, t, r, i) {
|
1229 |
+
if (!t) {
|
1230 |
+
this.sendFrame(P.frame(e, r), i);
|
1231 |
+
return;
|
1232 |
+
}
|
1233 |
+
const n = this._extensions[Ie.extensionName];
|
1234 |
+
this._bufferedBytes += r[x], this._deflating = !0, n.compress(e, r.fin, (o, l) => {
|
1235 |
+
if (this._socket.destroyed) {
|
1236 |
+
const f = new Error(
|
1237 |
+
"The socket was closed while data was being compressed"
|
1238 |
+
);
|
1239 |
+
typeof i == "function" && i(f);
|
1240 |
+
for (let a = 0; a < this._queue.length; a++) {
|
1241 |
+
const c = this._queue[a], h = c[c.length - 1];
|
1242 |
+
typeof h == "function" && h(f);
|
1243 |
+
}
|
1244 |
+
return;
|
1245 |
+
}
|
1246 |
+
this._bufferedBytes -= r[x], this._deflating = !1, r.readOnly = !1, this.sendFrame(P.frame(l, r), i), this.dequeue();
|
1247 |
+
});
|
1248 |
+
}
|
1249 |
+
/**
|
1250 |
+
* Executes queued send operations.
|
1251 |
+
*
|
1252 |
+
* @private
|
1253 |
+
*/
|
1254 |
+
dequeue() {
|
1255 |
+
for (; !this._deflating && this._queue.length; ) {
|
1256 |
+
const e = this._queue.shift();
|
1257 |
+
this._bufferedBytes -= e[3][x], Reflect.apply(e[0], this, e.slice(1));
|
1258 |
+
}
|
1259 |
+
}
|
1260 |
+
/**
|
1261 |
+
* Enqueues a send operation.
|
1262 |
+
*
|
1263 |
+
* @param {Array} params Send operation parameters.
|
1264 |
+
* @private
|
1265 |
+
*/
|
1266 |
+
enqueue(e) {
|
1267 |
+
this._bufferedBytes += e[3][x], this._queue.push(e);
|
1268 |
+
}
|
1269 |
+
/**
|
1270 |
+
* Sends a frame.
|
1271 |
+
*
|
1272 |
+
* @param {Buffer[]} list The frame to send
|
1273 |
+
* @param {Function} [cb] Callback
|
1274 |
+
* @private
|
1275 |
+
*/
|
1276 |
+
sendFrame(e, t) {
|
1277 |
+
e.length === 2 ? (this._socket.cork(), this._socket.write(e[0]), this._socket.write(e[1], t), this._socket.uncork()) : this._socket.write(e[0], t);
|
1278 |
+
}
|
1279 |
+
};
|
1280 |
+
var it = Jt;
|
1281 |
+
const Ks = /* @__PURE__ */ z(it), { kForOnEventAttribute: F, kListener: pe } = U, We = Symbol("kCode"), Ae = Symbol("kData"), Fe = Symbol("kError"), je = Symbol("kMessage"), Ge = Symbol("kReason"), I = Symbol("kTarget"), Ve = Symbol("kType"), He = Symbol("kWasClean");
|
1282 |
+
class B {
|
1283 |
+
/**
|
1284 |
+
* Create a new `Event`.
|
1285 |
+
*
|
1286 |
+
* @param {String} type The name of the event
|
1287 |
+
* @throws {TypeError} If the `type` argument is not specified
|
1288 |
+
*/
|
1289 |
+
constructor(e) {
|
1290 |
+
this[I] = null, this[Ve] = e;
|
1291 |
+
}
|
1292 |
+
/**
|
1293 |
+
* @type {*}
|
1294 |
+
*/
|
1295 |
+
get target() {
|
1296 |
+
return this[I];
|
1297 |
+
}
|
1298 |
+
/**
|
1299 |
+
* @type {String}
|
1300 |
+
*/
|
1301 |
+
get type() {
|
1302 |
+
return this[Ve];
|
1303 |
+
}
|
1304 |
+
}
|
1305 |
+
Object.defineProperty(B.prototype, "target", { enumerable: !0 });
|
1306 |
+
Object.defineProperty(B.prototype, "type", { enumerable: !0 });
|
1307 |
+
class Y extends B {
|
1308 |
+
/**
|
1309 |
+
* Create a new `CloseEvent`.
|
1310 |
+
*
|
1311 |
+
* @param {String} type The name of the event
|
1312 |
+
* @param {Object} [options] A dictionary object that allows for setting
|
1313 |
+
* attributes via object members of the same name
|
1314 |
+
* @param {Number} [options.code=0] The status code explaining why the
|
1315 |
+
* connection was closed
|
1316 |
+
* @param {String} [options.reason=''] A human-readable string explaining why
|
1317 |
+
* the connection was closed
|
1318 |
+
* @param {Boolean} [options.wasClean=false] Indicates whether or not the
|
1319 |
+
* connection was cleanly closed
|
1320 |
+
*/
|
1321 |
+
constructor(e, t = {}) {
|
1322 |
+
super(e), this[We] = t.code === void 0 ? 0 : t.code, this[Ge] = t.reason === void 0 ? "" : t.reason, this[He] = t.wasClean === void 0 ? !1 : t.wasClean;
|
1323 |
+
}
|
1324 |
+
/**
|
1325 |
+
* @type {Number}
|
1326 |
+
*/
|
1327 |
+
get code() {
|
1328 |
+
return this[We];
|
1329 |
+
}
|
1330 |
+
/**
|
1331 |
+
* @type {String}
|
1332 |
+
*/
|
1333 |
+
get reason() {
|
1334 |
+
return this[Ge];
|
1335 |
+
}
|
1336 |
+
/**
|
1337 |
+
* @type {Boolean}
|
1338 |
+
*/
|
1339 |
+
get wasClean() {
|
1340 |
+
return this[He];
|
1341 |
+
}
|
1342 |
+
}
|
1343 |
+
Object.defineProperty(Y.prototype, "code", { enumerable: !0 });
|
1344 |
+
Object.defineProperty(Y.prototype, "reason", { enumerable: !0 });
|
1345 |
+
Object.defineProperty(Y.prototype, "wasClean", { enumerable: !0 });
|
1346 |
+
class le extends B {
|
1347 |
+
/**
|
1348 |
+
* Create a new `ErrorEvent`.
|
1349 |
+
*
|
1350 |
+
* @param {String} type The name of the event
|
1351 |
+
* @param {Object} [options] A dictionary object that allows for setting
|
1352 |
+
* attributes via object members of the same name
|
1353 |
+
* @param {*} [options.error=null] The error that generated this event
|
1354 |
+
* @param {String} [options.message=''] The error message
|
1355 |
+
*/
|
1356 |
+
constructor(e, t = {}) {
|
1357 |
+
super(e), this[Fe] = t.error === void 0 ? null : t.error, this[je] = t.message === void 0 ? "" : t.message;
|
1358 |
+
}
|
1359 |
+
/**
|
1360 |
+
* @type {*}
|
1361 |
+
*/
|
1362 |
+
get error() {
|
1363 |
+
return this[Fe];
|
1364 |
+
}
|
1365 |
+
/**
|
1366 |
+
* @type {String}
|
1367 |
+
*/
|
1368 |
+
get message() {
|
1369 |
+
return this[je];
|
1370 |
+
}
|
1371 |
+
}
|
1372 |
+
Object.defineProperty(le.prototype, "error", { enumerable: !0 });
|
1373 |
+
Object.defineProperty(le.prototype, "message", { enumerable: !0 });
|
1374 |
+
class xe extends B {
|
1375 |
+
/**
|
1376 |
+
* Create a new `MessageEvent`.
|
1377 |
+
*
|
1378 |
+
* @param {String} type The name of the event
|
1379 |
+
* @param {Object} [options] A dictionary object that allows for setting
|
1380 |
+
* attributes via object members of the same name
|
1381 |
+
* @param {*} [options.data=null] The message content
|
1382 |
+
*/
|
1383 |
+
constructor(e, t = {}) {
|
1384 |
+
super(e), this[Ae] = t.data === void 0 ? null : t.data;
|
1385 |
+
}
|
1386 |
+
/**
|
1387 |
+
* @type {*}
|
1388 |
+
*/
|
1389 |
+
get data() {
|
1390 |
+
return this[Ae];
|
1391 |
+
}
|
1392 |
+
}
|
1393 |
+
Object.defineProperty(xe.prototype, "data", { enumerable: !0 });
|
1394 |
+
const es = {
|
1395 |
+
/**
|
1396 |
+
* Register an event listener.
|
1397 |
+
*
|
1398 |
+
* @param {String} type A string representing the event type to listen for
|
1399 |
+
* @param {(Function|Object)} handler The listener to add
|
1400 |
+
* @param {Object} [options] An options object specifies characteristics about
|
1401 |
+
* the event listener
|
1402 |
+
* @param {Boolean} [options.once=false] A `Boolean` indicating that the
|
1403 |
+
* listener should be invoked at most once after being added. If `true`,
|
1404 |
+
* the listener would be automatically removed when invoked.
|
1405 |
+
* @public
|
1406 |
+
*/
|
1407 |
+
addEventListener(s, e, t = {}) {
|
1408 |
+
for (const i of this.listeners(s))
|
1409 |
+
if (!t[F] && i[pe] === e && !i[F])
|
1410 |
+
return;
|
1411 |
+
let r;
|
1412 |
+
if (s === "message")
|
1413 |
+
r = function(n, o) {
|
1414 |
+
const l = new xe("message", {
|
1415 |
+
data: o ? n : n.toString()
|
1416 |
+
});
|
1417 |
+
l[I] = this, Z(e, this, l);
|
1418 |
+
};
|
1419 |
+
else if (s === "close")
|
1420 |
+
r = function(n, o) {
|
1421 |
+
const l = new Y("close", {
|
1422 |
+
code: n,
|
1423 |
+
reason: o.toString(),
|
1424 |
+
wasClean: this._closeFrameReceived && this._closeFrameSent
|
1425 |
+
});
|
1426 |
+
l[I] = this, Z(e, this, l);
|
1427 |
+
};
|
1428 |
+
else if (s === "error")
|
1429 |
+
r = function(n) {
|
1430 |
+
const o = new le("error", {
|
1431 |
+
error: n,
|
1432 |
+
message: n.message
|
1433 |
+
});
|
1434 |
+
o[I] = this, Z(e, this, o);
|
1435 |
+
};
|
1436 |
+
else if (s === "open")
|
1437 |
+
r = function() {
|
1438 |
+
const n = new B("open");
|
1439 |
+
n[I] = this, Z(e, this, n);
|
1440 |
+
};
|
1441 |
+
else
|
1442 |
+
return;
|
1443 |
+
r[F] = !!t[F], r[pe] = e, t.once ? this.once(s, r) : this.on(s, r);
|
1444 |
+
},
|
1445 |
+
/**
|
1446 |
+
* Remove an event listener.
|
1447 |
+
*
|
1448 |
+
* @param {String} type A string representing the event type to remove
|
1449 |
+
* @param {(Function|Object)} handler The listener to remove
|
1450 |
+
* @public
|
1451 |
+
*/
|
1452 |
+
removeEventListener(s, e) {
|
1453 |
+
for (const t of this.listeners(s))
|
1454 |
+
if (t[pe] === e && !t[F]) {
|
1455 |
+
this.removeListener(s, t);
|
1456 |
+
break;
|
1457 |
+
}
|
1458 |
+
}
|
1459 |
+
};
|
1460 |
+
var ts = {
|
1461 |
+
CloseEvent: Y,
|
1462 |
+
ErrorEvent: le,
|
1463 |
+
Event: B,
|
1464 |
+
EventTarget: es,
|
1465 |
+
MessageEvent: xe
|
1466 |
+
};
|
1467 |
+
function Z(s, e, t) {
|
1468 |
+
typeof s == "object" && s.handleEvent ? s.handleEvent.call(s, t) : s.call(e, t);
|
1469 |
+
}
|
1470 |
+
const { tokenChars: j } = ae;
|
1471 |
+
function k(s, e, t) {
|
1472 |
+
s[e] === void 0 ? s[e] = [t] : s[e].push(t);
|
1473 |
+
}
|
1474 |
+
function ss(s) {
|
1475 |
+
const e = /* @__PURE__ */ Object.create(null);
|
1476 |
+
let t = /* @__PURE__ */ Object.create(null), r = !1, i = !1, n = !1, o, l, f = -1, a = -1, c = -1, h = 0;
|
1477 |
+
for (; h < s.length; h++)
|
1478 |
+
if (a = s.charCodeAt(h), o === void 0)
|
1479 |
+
if (c === -1 && j[a] === 1)
|
1480 |
+
f === -1 && (f = h);
|
1481 |
+
else if (h !== 0 && (a === 32 || a === 9))
|
1482 |
+
c === -1 && f !== -1 && (c = h);
|
1483 |
+
else if (a === 59 || a === 44) {
|
1484 |
+
if (f === -1)
|
1485 |
+
throw new SyntaxError(`Unexpected character at index ${h}`);
|
1486 |
+
c === -1 && (c = h);
|
1487 |
+
const v = s.slice(f, c);
|
1488 |
+
a === 44 ? (k(e, v, t), t = /* @__PURE__ */ Object.create(null)) : o = v, f = c = -1;
|
1489 |
+
} else
|
1490 |
+
throw new SyntaxError(`Unexpected character at index ${h}`);
|
1491 |
+
else if (l === void 0)
|
1492 |
+
if (c === -1 && j[a] === 1)
|
1493 |
+
f === -1 && (f = h);
|
1494 |
+
else if (a === 32 || a === 9)
|
1495 |
+
c === -1 && f !== -1 && (c = h);
|
1496 |
+
else if (a === 59 || a === 44) {
|
1497 |
+
if (f === -1)
|
1498 |
+
throw new SyntaxError(`Unexpected character at index ${h}`);
|
1499 |
+
c === -1 && (c = h), k(t, s.slice(f, c), !0), a === 44 && (k(e, o, t), t = /* @__PURE__ */ Object.create(null), o = void 0), f = c = -1;
|
1500 |
+
} else if (a === 61 && f !== -1 && c === -1)
|
1501 |
+
l = s.slice(f, h), f = c = -1;
|
1502 |
+
else
|
1503 |
+
throw new SyntaxError(`Unexpected character at index ${h}`);
|
1504 |
+
else if (i) {
|
1505 |
+
if (j[a] !== 1)
|
1506 |
+
throw new SyntaxError(`Unexpected character at index ${h}`);
|
1507 |
+
f === -1 ? f = h : r || (r = !0), i = !1;
|
1508 |
+
} else if (n)
|
1509 |
+
if (j[a] === 1)
|
1510 |
+
f === -1 && (f = h);
|
1511 |
+
else if (a === 34 && f !== -1)
|
1512 |
+
n = !1, c = h;
|
1513 |
+
else if (a === 92)
|
1514 |
+
i = !0;
|
1515 |
+
else
|
1516 |
+
throw new SyntaxError(`Unexpected character at index ${h}`);
|
1517 |
+
else if (a === 34 && s.charCodeAt(h - 1) === 61)
|
1518 |
+
n = !0;
|
1519 |
+
else if (c === -1 && j[a] === 1)
|
1520 |
+
f === -1 && (f = h);
|
1521 |
+
else if (f !== -1 && (a === 32 || a === 9))
|
1522 |
+
c === -1 && (c = h);
|
1523 |
+
else if (a === 59 || a === 44) {
|
1524 |
+
if (f === -1)
|
1525 |
+
throw new SyntaxError(`Unexpected character at index ${h}`);
|
1526 |
+
c === -1 && (c = h);
|
1527 |
+
let v = s.slice(f, c);
|
1528 |
+
r && (v = v.replace(/\\/g, ""), r = !1), k(t, l, v), a === 44 && (k(e, o, t), t = /* @__PURE__ */ Object.create(null), o = void 0), l = void 0, f = c = -1;
|
1529 |
+
} else
|
1530 |
+
throw new SyntaxError(`Unexpected character at index ${h}`);
|
1531 |
+
if (f === -1 || n || a === 32 || a === 9)
|
1532 |
+
throw new SyntaxError("Unexpected end of input");
|
1533 |
+
c === -1 && (c = h);
|
1534 |
+
const p = s.slice(f, c);
|
1535 |
+
return o === void 0 ? k(e, p, t) : (l === void 0 ? k(t, p, !0) : r ? k(t, l, p.replace(/\\/g, "")) : k(t, l, p), k(e, o, t)), e;
|
1536 |
+
}
|
1537 |
+
function rs(s) {
|
1538 |
+
return Object.keys(s).map((e) => {
|
1539 |
+
let t = s[e];
|
1540 |
+
return Array.isArray(t) || (t = [t]), t.map((r) => [e].concat(
|
1541 |
+
Object.keys(r).map((i) => {
|
1542 |
+
let n = r[i];
|
1543 |
+
return Array.isArray(n) || (n = [n]), n.map((o) => o === !0 ? i : `${i}=${o}`).join("; ");
|
1544 |
+
})
|
1545 |
+
).join("; ")).join(", ");
|
1546 |
+
}).join(", ");
|
1547 |
+
}
|
1548 |
+
var nt = { format: rs, parse: ss };
|
1549 |
+
const is = S, ns = S, os = S, ot = S, as = S, { randomBytes: ls, createHash: fs } = S, { URL: me } = S, T = oe, hs = rt, cs = it, {
|
1550 |
+
BINARY_TYPES: ze,
|
1551 |
+
EMPTY_BUFFER: Q,
|
1552 |
+
GUID: us,
|
1553 |
+
kForOnEventAttribute: ge,
|
1554 |
+
kListener: ds,
|
1555 |
+
kStatusCode: _s,
|
1556 |
+
kWebSocket: y,
|
1557 |
+
NOOP: at
|
1558 |
+
} = U, {
|
1559 |
+
EventTarget: { addEventListener: ps, removeEventListener: ms }
|
1560 |
+
} = ts, { format: gs, parse: ys } = nt, { toBuffer: vs } = ne, Ss = 30 * 1e3, lt = Symbol("kAborted"), ye = [8, 13], O = ["CONNECTING", "OPEN", "CLOSING", "CLOSED"], Es = /^[!#$%&'*+\-.0-9A-Z^_`|a-z~]+$/;
|
1561 |
+
let m = class d extends is {
|
1562 |
+
/**
|
1563 |
+
* Create a new `WebSocket`.
|
1564 |
+
*
|
1565 |
+
* @param {(String|URL)} address The URL to which to connect
|
1566 |
+
* @param {(String|String[])} [protocols] The subprotocols
|
1567 |
+
* @param {Object} [options] Connection options
|
1568 |
+
*/
|
1569 |
+
constructor(e, t, r) {
|
1570 |
+
super(), this._binaryType = ze[0], this._closeCode = 1006, this._closeFrameReceived = !1, this._closeFrameSent = !1, this._closeMessage = Q, this._closeTimer = null, this._extensions = {}, this._paused = !1, this._protocol = "", this._readyState = d.CONNECTING, this._receiver = null, this._sender = null, this._socket = null, e !== null ? (this._bufferedAmount = 0, this._isServer = !1, this._redirects = 0, t === void 0 ? t = [] : Array.isArray(t) || (typeof t == "object" && t !== null ? (r = t, t = []) : t = [t]), ht(this, e, t, r)) : this._isServer = !0;
|
1571 |
+
}
|
1572 |
+
/**
|
1573 |
+
* This deviates from the WHATWG interface since ws doesn't support the
|
1574 |
+
* required default "blob" type (instead we define a custom "nodebuffer"
|
1575 |
+
* type).
|
1576 |
+
*
|
1577 |
+
* @type {String}
|
1578 |
+
*/
|
1579 |
+
get binaryType() {
|
1580 |
+
return this._binaryType;
|
1581 |
+
}
|
1582 |
+
set binaryType(e) {
|
1583 |
+
ze.includes(e) && (this._binaryType = e, this._receiver && (this._receiver._binaryType = e));
|
1584 |
+
}
|
1585 |
+
/**
|
1586 |
+
* @type {Number}
|
1587 |
+
*/
|
1588 |
+
get bufferedAmount() {
|
1589 |
+
return this._socket ? this._socket._writableState.length + this._sender._bufferedBytes : this._bufferedAmount;
|
1590 |
+
}
|
1591 |
+
/**
|
1592 |
+
* @type {String}
|
1593 |
+
*/
|
1594 |
+
get extensions() {
|
1595 |
+
return Object.keys(this._extensions).join();
|
1596 |
+
}
|
1597 |
+
/**
|
1598 |
+
* @type {Boolean}
|
1599 |
+
*/
|
1600 |
+
get isPaused() {
|
1601 |
+
return this._paused;
|
1602 |
+
}
|
1603 |
+
/**
|
1604 |
+
* @type {Function}
|
1605 |
+
*/
|
1606 |
+
/* istanbul ignore next */
|
1607 |
+
get onclose() {
|
1608 |
+
return null;
|
1609 |
+
}
|
1610 |
+
/**
|
1611 |
+
* @type {Function}
|
1612 |
+
*/
|
1613 |
+
/* istanbul ignore next */
|
1614 |
+
get onerror() {
|
1615 |
+
return null;
|
1616 |
+
}
|
1617 |
+
/**
|
1618 |
+
* @type {Function}
|
1619 |
+
*/
|
1620 |
+
/* istanbul ignore next */
|
1621 |
+
get onopen() {
|
1622 |
+
return null;
|
1623 |
+
}
|
1624 |
+
/**
|
1625 |
+
* @type {Function}
|
1626 |
+
*/
|
1627 |
+
/* istanbul ignore next */
|
1628 |
+
get onmessage() {
|
1629 |
+
return null;
|
1630 |
+
}
|
1631 |
+
/**
|
1632 |
+
* @type {String}
|
1633 |
+
*/
|
1634 |
+
get protocol() {
|
1635 |
+
return this._protocol;
|
1636 |
+
}
|
1637 |
+
/**
|
1638 |
+
* @type {Number}
|
1639 |
+
*/
|
1640 |
+
get readyState() {
|
1641 |
+
return this._readyState;
|
1642 |
+
}
|
1643 |
+
/**
|
1644 |
+
* @type {String}
|
1645 |
+
*/
|
1646 |
+
get url() {
|
1647 |
+
return this._url;
|
1648 |
+
}
|
1649 |
+
/**
|
1650 |
+
* Set up the socket and the internal resources.
|
1651 |
+
*
|
1652 |
+
* @param {(net.Socket|tls.Socket)} socket The network socket between the
|
1653 |
+
* server and client
|
1654 |
+
* @param {Buffer} head The first packet of the upgraded stream
|
1655 |
+
* @param {Object} options Options object
|
1656 |
+
* @param {Function} [options.generateMask] The function used to generate the
|
1657 |
+
* masking key
|
1658 |
+
* @param {Number} [options.maxPayload=0] The maximum allowed message size
|
1659 |
+
* @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
|
1660 |
+
* not to skip UTF-8 validation for text and close messages
|
1661 |
+
* @private
|
1662 |
+
*/
|
1663 |
+
setSocket(e, t, r) {
|
1664 |
+
const i = new hs({
|
1665 |
+
binaryType: this.binaryType,
|
1666 |
+
extensions: this._extensions,
|
1667 |
+
isServer: this._isServer,
|
1668 |
+
maxPayload: r.maxPayload,
|
1669 |
+
skipUTF8Validation: r.skipUTF8Validation
|
1670 |
+
});
|
1671 |
+
this._sender = new cs(e, this._extensions, r.generateMask), this._receiver = i, this._socket = e, i[y] = this, e[y] = this, i.on("conclude", ks), i.on("drain", ws), i.on("error", Os), i.on("message", Cs), i.on("ping", Ts), i.on("pong", Ls), e.setTimeout(0), e.setNoDelay(), t.length > 0 && e.unshift(t), e.on("close", ut), e.on("data", fe), e.on("end", dt), e.on("error", _t), this._readyState = d.OPEN, this.emit("open");
|
1672 |
+
}
|
1673 |
+
/**
|
1674 |
+
* Emit the `'close'` event.
|
1675 |
+
*
|
1676 |
+
* @private
|
1677 |
+
*/
|
1678 |
+
emitClose() {
|
1679 |
+
if (!this._socket) {
|
1680 |
+
this._readyState = d.CLOSED, this.emit("close", this._closeCode, this._closeMessage);
|
1681 |
+
return;
|
1682 |
+
}
|
1683 |
+
this._extensions[T.extensionName] && this._extensions[T.extensionName].cleanup(), this._receiver.removeAllListeners(), this._readyState = d.CLOSED, this.emit("close", this._closeCode, this._closeMessage);
|
1684 |
+
}
|
1685 |
+
/**
|
1686 |
+
* Start a closing handshake.
|
1687 |
+
*
|
1688 |
+
* +----------+ +-----------+ +----------+
|
1689 |
+
* - - -|ws.close()|-->|close frame|-->|ws.close()|- - -
|
1690 |
+
* | +----------+ +-----------+ +----------+ |
|
1691 |
+
* +----------+ +-----------+ |
|
1692 |
+
* CLOSING |ws.close()|<--|close frame|<--+-----+ CLOSING
|
1693 |
+
* +----------+ +-----------+ |
|
1694 |
+
* | | | +---+ |
|
1695 |
+
* +------------------------+-->|fin| - - - -
|
1696 |
+
* | +---+ | +---+
|
1697 |
+
* - - - - -|fin|<---------------------+
|
1698 |
+
* +---+
|
1699 |
+
*
|
1700 |
+
* @param {Number} [code] Status code explaining why the connection is closing
|
1701 |
+
* @param {(String|Buffer)} [data] The reason why the connection is
|
1702 |
+
* closing
|
1703 |
+
* @public
|
1704 |
+
*/
|
1705 |
+
close(e, t) {
|
1706 |
+
if (this.readyState !== d.CLOSED) {
|
1707 |
+
if (this.readyState === d.CONNECTING) {
|
1708 |
+
b(this, this._req, "WebSocket was closed before the connection was established");
|
1709 |
+
return;
|
1710 |
+
}
|
1711 |
+
if (this.readyState === d.CLOSING) {
|
1712 |
+
this._closeFrameSent && (this._closeFrameReceived || this._receiver._writableState.errorEmitted) && this._socket.end();
|
1713 |
+
return;
|
1714 |
+
}
|
1715 |
+
this._readyState = d.CLOSING, this._sender.close(e, t, !this._isServer, (r) => {
|
1716 |
+
r || (this._closeFrameSent = !0, (this._closeFrameReceived || this._receiver._writableState.errorEmitted) && this._socket.end());
|
1717 |
+
}), this._closeTimer = setTimeout(
|
1718 |
+
this._socket.destroy.bind(this._socket),
|
1719 |
+
Ss
|
1720 |
+
);
|
1721 |
+
}
|
1722 |
+
}
|
1723 |
+
/**
|
1724 |
+
* Pause the socket.
|
1725 |
+
*
|
1726 |
+
* @public
|
1727 |
+
*/
|
1728 |
+
pause() {
|
1729 |
+
this.readyState === d.CONNECTING || this.readyState === d.CLOSED || (this._paused = !0, this._socket.pause());
|
1730 |
+
}
|
1731 |
+
/**
|
1732 |
+
* Send a ping.
|
1733 |
+
*
|
1734 |
+
* @param {*} [data] The data to send
|
1735 |
+
* @param {Boolean} [mask] Indicates whether or not to mask `data`
|
1736 |
+
* @param {Function} [cb] Callback which is executed when the ping is sent
|
1737 |
+
* @public
|
1738 |
+
*/
|
1739 |
+
ping(e, t, r) {
|
1740 |
+
if (this.readyState === d.CONNECTING)
|
1741 |
+
throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
|
1742 |
+
if (typeof e == "function" ? (r = e, e = t = void 0) : typeof t == "function" && (r = t, t = void 0), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
|
1743 |
+
ve(this, e, r);
|
1744 |
+
return;
|
1745 |
+
}
|
1746 |
+
t === void 0 && (t = !this._isServer), this._sender.ping(e || Q, t, r);
|
1747 |
+
}
|
1748 |
+
/**
|
1749 |
+
* Send a pong.
|
1750 |
+
*
|
1751 |
+
* @param {*} [data] The data to send
|
1752 |
+
* @param {Boolean} [mask] Indicates whether or not to mask `data`
|
1753 |
+
* @param {Function} [cb] Callback which is executed when the pong is sent
|
1754 |
+
* @public
|
1755 |
+
*/
|
1756 |
+
pong(e, t, r) {
|
1757 |
+
if (this.readyState === d.CONNECTING)
|
1758 |
+
throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
|
1759 |
+
if (typeof e == "function" ? (r = e, e = t = void 0) : typeof t == "function" && (r = t, t = void 0), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
|
1760 |
+
ve(this, e, r);
|
1761 |
+
return;
|
1762 |
+
}
|
1763 |
+
t === void 0 && (t = !this._isServer), this._sender.pong(e || Q, t, r);
|
1764 |
+
}
|
1765 |
+
/**
|
1766 |
+
* Resume the socket.
|
1767 |
+
*
|
1768 |
+
* @public
|
1769 |
+
*/
|
1770 |
+
resume() {
|
1771 |
+
this.readyState === d.CONNECTING || this.readyState === d.CLOSED || (this._paused = !1, this._receiver._writableState.needDrain || this._socket.resume());
|
1772 |
+
}
|
1773 |
+
/**
|
1774 |
+
* Send a data message.
|
1775 |
+
*
|
1776 |
+
* @param {*} data The message to send
|
1777 |
+
* @param {Object} [options] Options object
|
1778 |
+
* @param {Boolean} [options.binary] Specifies whether `data` is binary or
|
1779 |
+
* text
|
1780 |
+
* @param {Boolean} [options.compress] Specifies whether or not to compress
|
1781 |
+
* `data`
|
1782 |
+
* @param {Boolean} [options.fin=true] Specifies whether the fragment is the
|
1783 |
+
* last one
|
1784 |
+
* @param {Boolean} [options.mask] Specifies whether or not to mask `data`
|
1785 |
+
* @param {Function} [cb] Callback which is executed when data is written out
|
1786 |
+
* @public
|
1787 |
+
*/
|
1788 |
+
send(e, t, r) {
|
1789 |
+
if (this.readyState === d.CONNECTING)
|
1790 |
+
throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
|
1791 |
+
if (typeof t == "function" && (r = t, t = {}), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
|
1792 |
+
ve(this, e, r);
|
1793 |
+
return;
|
1794 |
+
}
|
1795 |
+
const i = {
|
1796 |
+
binary: typeof e != "string",
|
1797 |
+
mask: !this._isServer,
|
1798 |
+
compress: !0,
|
1799 |
+
fin: !0,
|
1800 |
+
...t
|
1801 |
+
};
|
1802 |
+
this._extensions[T.extensionName] || (i.compress = !1), this._sender.send(e || Q, i, r);
|
1803 |
+
}
|
1804 |
+
/**
|
1805 |
+
* Forcibly close the connection.
|
1806 |
+
*
|
1807 |
+
* @public
|
1808 |
+
*/
|
1809 |
+
terminate() {
|
1810 |
+
if (this.readyState !== d.CLOSED) {
|
1811 |
+
if (this.readyState === d.CONNECTING) {
|
1812 |
+
b(this, this._req, "WebSocket was closed before the connection was established");
|
1813 |
+
return;
|
1814 |
+
}
|
1815 |
+
this._socket && (this._readyState = d.CLOSING, this._socket.destroy());
|
1816 |
+
}
|
1817 |
+
}
|
1818 |
+
};
|
1819 |
+
Object.defineProperty(m, "CONNECTING", {
|
1820 |
+
enumerable: !0,
|
1821 |
+
value: O.indexOf("CONNECTING")
|
1822 |
+
});
|
1823 |
+
Object.defineProperty(m.prototype, "CONNECTING", {
|
1824 |
+
enumerable: !0,
|
1825 |
+
value: O.indexOf("CONNECTING")
|
1826 |
+
});
|
1827 |
+
Object.defineProperty(m, "OPEN", {
|
1828 |
+
enumerable: !0,
|
1829 |
+
value: O.indexOf("OPEN")
|
1830 |
+
});
|
1831 |
+
Object.defineProperty(m.prototype, "OPEN", {
|
1832 |
+
enumerable: !0,
|
1833 |
+
value: O.indexOf("OPEN")
|
1834 |
+
});
|
1835 |
+
Object.defineProperty(m, "CLOSING", {
|
1836 |
+
enumerable: !0,
|
1837 |
+
value: O.indexOf("CLOSING")
|
1838 |
+
});
|
1839 |
+
Object.defineProperty(m.prototype, "CLOSING", {
|
1840 |
+
enumerable: !0,
|
1841 |
+
value: O.indexOf("CLOSING")
|
1842 |
+
});
|
1843 |
+
Object.defineProperty(m, "CLOSED", {
|
1844 |
+
enumerable: !0,
|
1845 |
+
value: O.indexOf("CLOSED")
|
1846 |
+
});
|
1847 |
+
Object.defineProperty(m.prototype, "CLOSED", {
|
1848 |
+
enumerable: !0,
|
1849 |
+
value: O.indexOf("CLOSED")
|
1850 |
+
});
|
1851 |
+
[
|
1852 |
+
"binaryType",
|
1853 |
+
"bufferedAmount",
|
1854 |
+
"extensions",
|
1855 |
+
"isPaused",
|
1856 |
+
"protocol",
|
1857 |
+
"readyState",
|
1858 |
+
"url"
|
1859 |
+
].forEach((s) => {
|
1860 |
+
Object.defineProperty(m.prototype, s, { enumerable: !0 });
|
1861 |
+
});
|
1862 |
+
["open", "error", "close", "message"].forEach((s) => {
|
1863 |
+
Object.defineProperty(m.prototype, `on${s}`, {
|
1864 |
+
enumerable: !0,
|
1865 |
+
get() {
|
1866 |
+
for (const e of this.listeners(s))
|
1867 |
+
if (e[ge])
|
1868 |
+
return e[ds];
|
1869 |
+
return null;
|
1870 |
+
},
|
1871 |
+
set(e) {
|
1872 |
+
for (const t of this.listeners(s))
|
1873 |
+
if (t[ge]) {
|
1874 |
+
this.removeListener(s, t);
|
1875 |
+
break;
|
1876 |
+
}
|
1877 |
+
typeof e == "function" && this.addEventListener(s, e, {
|
1878 |
+
[ge]: !0
|
1879 |
+
});
|
1880 |
+
}
|
1881 |
+
});
|
1882 |
+
});
|
1883 |
+
m.prototype.addEventListener = ps;
|
1884 |
+
m.prototype.removeEventListener = ms;
|
1885 |
+
var ft = m;
|
1886 |
+
function ht(s, e, t, r) {
|
1887 |
+
const i = {
|
1888 |
+
protocolVersion: ye[1],
|
1889 |
+
maxPayload: 104857600,
|
1890 |
+
skipUTF8Validation: !1,
|
1891 |
+
perMessageDeflate: !0,
|
1892 |
+
followRedirects: !1,
|
1893 |
+
maxRedirects: 10,
|
1894 |
+
...r,
|
1895 |
+
createConnection: void 0,
|
1896 |
+
socketPath: void 0,
|
1897 |
+
hostname: void 0,
|
1898 |
+
protocol: void 0,
|
1899 |
+
timeout: void 0,
|
1900 |
+
method: "GET",
|
1901 |
+
host: void 0,
|
1902 |
+
path: void 0,
|
1903 |
+
port: void 0
|
1904 |
+
};
|
1905 |
+
if (!ye.includes(i.protocolVersion))
|
1906 |
+
throw new RangeError(
|
1907 |
+
`Unsupported protocol version: ${i.protocolVersion} (supported versions: ${ye.join(", ")})`
|
1908 |
+
);
|
1909 |
+
let n;
|
1910 |
+
if (e instanceof me)
|
1911 |
+
n = e, s._url = e.href;
|
1912 |
+
else {
|
1913 |
+
try {
|
1914 |
+
n = new me(e);
|
1915 |
+
} catch {
|
1916 |
+
throw new SyntaxError(`Invalid URL: ${e}`);
|
1917 |
+
}
|
1918 |
+
s._url = e;
|
1919 |
+
}
|
1920 |
+
const o = n.protocol === "wss:", l = n.protocol === "ws+unix:";
|
1921 |
+
let f;
|
1922 |
+
if (n.protocol !== "ws:" && !o && !l ? f = `The URL's protocol must be one of "ws:", "wss:", or "ws+unix:"` : l && !n.pathname ? f = "The URL's pathname is empty" : n.hash && (f = "The URL contains a fragment identifier"), f) {
|
1923 |
+
const u = new SyntaxError(f);
|
1924 |
+
if (s._redirects === 0)
|
1925 |
+
throw u;
|
1926 |
+
ee(s, u);
|
1927 |
+
return;
|
1928 |
+
}
|
1929 |
+
const a = o ? 443 : 80, c = ls(16).toString("base64"), h = o ? ns.request : os.request, p = /* @__PURE__ */ new Set();
|
1930 |
+
let v;
|
1931 |
+
if (i.createConnection = o ? xs : bs, i.defaultPort = i.defaultPort || a, i.port = n.port || a, i.host = n.hostname.startsWith("[") ? n.hostname.slice(1, -1) : n.hostname, i.headers = {
|
1932 |
+
...i.headers,
|
1933 |
+
"Sec-WebSocket-Version": i.protocolVersion,
|
1934 |
+
"Sec-WebSocket-Key": c,
|
1935 |
+
Connection: "Upgrade",
|
1936 |
+
Upgrade: "websocket"
|
1937 |
+
}, i.path = n.pathname + n.search, i.timeout = i.handshakeTimeout, i.perMessageDeflate && (v = new T(
|
1938 |
+
i.perMessageDeflate !== !0 ? i.perMessageDeflate : {},
|
1939 |
+
!1,
|
1940 |
+
i.maxPayload
|
1941 |
+
), i.headers["Sec-WebSocket-Extensions"] = gs({
|
1942 |
+
[T.extensionName]: v.offer()
|
1943 |
+
})), t.length) {
|
1944 |
+
for (const u of t) {
|
1945 |
+
if (typeof u != "string" || !Es.test(u) || p.has(u))
|
1946 |
+
throw new SyntaxError(
|
1947 |
+
"An invalid or duplicated subprotocol was specified"
|
1948 |
+
);
|
1949 |
+
p.add(u);
|
1950 |
+
}
|
1951 |
+
i.headers["Sec-WebSocket-Protocol"] = t.join(",");
|
1952 |
+
}
|
1953 |
+
if (i.origin && (i.protocolVersion < 13 ? i.headers["Sec-WebSocket-Origin"] = i.origin : i.headers.Origin = i.origin), (n.username || n.password) && (i.auth = `${n.username}:${n.password}`), l) {
|
1954 |
+
const u = i.path.split(":");
|
1955 |
+
i.socketPath = u[0], i.path = u[1];
|
1956 |
+
}
|
1957 |
+
let _;
|
1958 |
+
if (i.followRedirects) {
|
1959 |
+
if (s._redirects === 0) {
|
1960 |
+
s._originalIpc = l, s._originalSecure = o, s._originalHostOrSocketPath = l ? i.socketPath : n.host;
|
1961 |
+
const u = r && r.headers;
|
1962 |
+
if (r = { ...r, headers: {} }, u)
|
1963 |
+
for (const [E, $] of Object.entries(u))
|
1964 |
+
r.headers[E.toLowerCase()] = $;
|
1965 |
+
} else if (s.listenerCount("redirect") === 0) {
|
1966 |
+
const u = l ? s._originalIpc ? i.socketPath === s._originalHostOrSocketPath : !1 : s._originalIpc ? !1 : n.host === s._originalHostOrSocketPath;
|
1967 |
+
(!u || s._originalSecure && !o) && (delete i.headers.authorization, delete i.headers.cookie, u || delete i.headers.host, i.auth = void 0);
|
1968 |
+
}
|
1969 |
+
i.auth && !r.headers.authorization && (r.headers.authorization = "Basic " + Buffer.from(i.auth).toString("base64")), _ = s._req = h(i), s._redirects && s.emit("redirect", s.url, _);
|
1970 |
+
} else
|
1971 |
+
_ = s._req = h(i);
|
1972 |
+
i.timeout && _.on("timeout", () => {
|
1973 |
+
b(s, _, "Opening handshake has timed out");
|
1974 |
+
}), _.on("error", (u) => {
|
1975 |
+
_ === null || _[lt] || (_ = s._req = null, ee(s, u));
|
1976 |
+
}), _.on("response", (u) => {
|
1977 |
+
const E = u.headers.location, $ = u.statusCode;
|
1978 |
+
if (E && i.followRedirects && $ >= 300 && $ < 400) {
|
1979 |
+
if (++s._redirects > i.maxRedirects) {
|
1980 |
+
b(s, _, "Maximum redirects exceeded");
|
1981 |
+
return;
|
1982 |
+
}
|
1983 |
+
_.abort();
|
1984 |
+
let q;
|
1985 |
+
try {
|
1986 |
+
q = new me(E, e);
|
1987 |
+
} catch {
|
1988 |
+
const L = new SyntaxError(`Invalid URL: ${E}`);
|
1989 |
+
ee(s, L);
|
1990 |
+
return;
|
1991 |
+
}
|
1992 |
+
ht(s, q, t, r);
|
1993 |
+
} else
|
1994 |
+
s.emit("unexpected-response", _, u) || b(
|
1995 |
+
s,
|
1996 |
+
_,
|
1997 |
+
`Unexpected server response: ${u.statusCode}`
|
1998 |
+
);
|
1999 |
+
}), _.on("upgrade", (u, E, $) => {
|
2000 |
+
if (s.emit("upgrade", u), s.readyState !== m.CONNECTING)
|
2001 |
+
return;
|
2002 |
+
if (_ = s._req = null, u.headers.upgrade.toLowerCase() !== "websocket") {
|
2003 |
+
b(s, E, "Invalid Upgrade header");
|
2004 |
+
return;
|
2005 |
+
}
|
2006 |
+
const q = fs("sha1").update(c + us).digest("base64");
|
2007 |
+
if (u.headers["sec-websocket-accept"] !== q) {
|
2008 |
+
b(s, E, "Invalid Sec-WebSocket-Accept header");
|
2009 |
+
return;
|
2010 |
+
}
|
2011 |
+
const D = u.headers["sec-websocket-protocol"];
|
2012 |
+
let L;
|
2013 |
+
if (D !== void 0 ? p.size ? p.has(D) || (L = "Server sent an invalid subprotocol") : L = "Server sent a subprotocol but none was requested" : p.size && (L = "Server sent no subprotocol"), L) {
|
2014 |
+
b(s, E, L);
|
2015 |
+
return;
|
2016 |
+
}
|
2017 |
+
D && (s._protocol = D);
|
2018 |
+
const ke = u.headers["sec-websocket-extensions"];
|
2019 |
+
if (ke !== void 0) {
|
2020 |
+
if (!v) {
|
2021 |
+
b(s, E, "Server sent a Sec-WebSocket-Extensions header but no extension was requested");
|
2022 |
+
return;
|
2023 |
+
}
|
2024 |
+
let he;
|
2025 |
+
try {
|
2026 |
+
he = ys(ke);
|
2027 |
+
} catch {
|
2028 |
+
b(s, E, "Invalid Sec-WebSocket-Extensions header");
|
2029 |
+
return;
|
2030 |
+
}
|
2031 |
+
const we = Object.keys(he);
|
2032 |
+
if (we.length !== 1 || we[0] !== T.extensionName) {
|
2033 |
+
b(s, E, "Server indicated an extension that was not requested");
|
2034 |
+
return;
|
2035 |
+
}
|
2036 |
+
try {
|
2037 |
+
v.accept(he[T.extensionName]);
|
2038 |
+
} catch {
|
2039 |
+
b(s, E, "Invalid Sec-WebSocket-Extensions header");
|
2040 |
+
return;
|
2041 |
+
}
|
2042 |
+
s._extensions[T.extensionName] = v;
|
2043 |
+
}
|
2044 |
+
s.setSocket(E, $, {
|
2045 |
+
generateMask: i.generateMask,
|
2046 |
+
maxPayload: i.maxPayload,
|
2047 |
+
skipUTF8Validation: i.skipUTF8Validation
|
2048 |
+
});
|
2049 |
+
}), i.finishRequest ? i.finishRequest(_, s) : _.end();
|
2050 |
+
}
|
2051 |
+
function ee(s, e) {
|
2052 |
+
s._readyState = m.CLOSING, s.emit("error", e), s.emitClose();
|
2053 |
+
}
|
2054 |
+
function bs(s) {
|
2055 |
+
return s.path = s.socketPath, ot.connect(s);
|
2056 |
+
}
|
2057 |
+
function xs(s) {
|
2058 |
+
return s.path = void 0, !s.servername && s.servername !== "" && (s.servername = ot.isIP(s.host) ? "" : s.host), as.connect(s);
|
2059 |
+
}
|
2060 |
+
function b(s, e, t) {
|
2061 |
+
s._readyState = m.CLOSING;
|
2062 |
+
const r = new Error(t);
|
2063 |
+
Error.captureStackTrace(r, b), e.setHeader ? (e[lt] = !0, e.abort(), e.socket && !e.socket.destroyed && e.socket.destroy(), process.nextTick(ee, s, r)) : (e.destroy(r), e.once("error", s.emit.bind(s, "error")), e.once("close", s.emitClose.bind(s)));
|
2064 |
+
}
|
2065 |
+
function ve(s, e, t) {
|
2066 |
+
if (e) {
|
2067 |
+
const r = vs(e).length;
|
2068 |
+
s._socket ? s._sender._bufferedBytes += r : s._bufferedAmount += r;
|
2069 |
+
}
|
2070 |
+
if (t) {
|
2071 |
+
const r = new Error(
|
2072 |
+
`WebSocket is not open: readyState ${s.readyState} (${O[s.readyState]})`
|
2073 |
+
);
|
2074 |
+
process.nextTick(t, r);
|
2075 |
+
}
|
2076 |
+
}
|
2077 |
+
function ks(s, e) {
|
2078 |
+
const t = this[y];
|
2079 |
+
t._closeFrameReceived = !0, t._closeMessage = e, t._closeCode = s, t._socket[y] !== void 0 && (t._socket.removeListener("data", fe), process.nextTick(ct, t._socket), s === 1005 ? t.close() : t.close(s, e));
|
2080 |
+
}
|
2081 |
+
function ws() {
|
2082 |
+
const s = this[y];
|
2083 |
+
s.isPaused || s._socket.resume();
|
2084 |
+
}
|
2085 |
+
function Os(s) {
|
2086 |
+
const e = this[y];
|
2087 |
+
e._socket[y] !== void 0 && (e._socket.removeListener("data", fe), process.nextTick(ct, e._socket), e.close(s[_s])), e.emit("error", s);
|
2088 |
+
}
|
2089 |
+
function Ye() {
|
2090 |
+
this[y].emitClose();
|
2091 |
+
}
|
2092 |
+
function Cs(s, e) {
|
2093 |
+
this[y].emit("message", s, e);
|
2094 |
+
}
|
2095 |
+
function Ts(s) {
|
2096 |
+
const e = this[y];
|
2097 |
+
e.pong(s, !e._isServer, at), e.emit("ping", s);
|
2098 |
+
}
|
2099 |
+
function Ls(s) {
|
2100 |
+
this[y].emit("pong", s);
|
2101 |
+
}
|
2102 |
+
function ct(s) {
|
2103 |
+
s.resume();
|
2104 |
+
}
|
2105 |
+
function ut() {
|
2106 |
+
const s = this[y];
|
2107 |
+
this.removeListener("close", ut), this.removeListener("data", fe), this.removeListener("end", dt), s._readyState = m.CLOSING;
|
2108 |
+
let e;
|
2109 |
+
!this._readableState.endEmitted && !s._closeFrameReceived && !s._receiver._writableState.errorEmitted && (e = s._socket.read()) !== null && s._receiver.write(e), s._receiver.end(), this[y] = void 0, clearTimeout(s._closeTimer), s._receiver._writableState.finished || s._receiver._writableState.errorEmitted ? s.emitClose() : (s._receiver.on("error", Ye), s._receiver.on("finish", Ye));
|
2110 |
+
}
|
2111 |
+
function fe(s) {
|
2112 |
+
this[y]._receiver.write(s) || this.pause();
|
2113 |
+
}
|
2114 |
+
function dt() {
|
2115 |
+
const s = this[y];
|
2116 |
+
s._readyState = m.CLOSING, s._receiver.end(), this.end();
|
2117 |
+
}
|
2118 |
+
function _t() {
|
2119 |
+
const s = this[y];
|
2120 |
+
this.removeListener("error", _t), this.on("error", at), s && (s._readyState = m.CLOSING, this.destroy());
|
2121 |
+
}
|
2122 |
+
const Xs = /* @__PURE__ */ z(ft), { tokenChars: Ns } = ae;
|
2123 |
+
function Ps(s) {
|
2124 |
+
const e = /* @__PURE__ */ new Set();
|
2125 |
+
let t = -1, r = -1, i = 0;
|
2126 |
+
for (i; i < s.length; i++) {
|
2127 |
+
const o = s.charCodeAt(i);
|
2128 |
+
if (r === -1 && Ns[o] === 1)
|
2129 |
+
t === -1 && (t = i);
|
2130 |
+
else if (i !== 0 && (o === 32 || o === 9))
|
2131 |
+
r === -1 && t !== -1 && (r = i);
|
2132 |
+
else if (o === 44) {
|
2133 |
+
if (t === -1)
|
2134 |
+
throw new SyntaxError(`Unexpected character at index ${i}`);
|
2135 |
+
r === -1 && (r = i);
|
2136 |
+
const l = s.slice(t, r);
|
2137 |
+
if (e.has(l))
|
2138 |
+
throw new SyntaxError(`The "${l}" subprotocol is duplicated`);
|
2139 |
+
e.add(l), t = r = -1;
|
2140 |
+
} else
|
2141 |
+
throw new SyntaxError(`Unexpected character at index ${i}`);
|
2142 |
+
}
|
2143 |
+
if (t === -1 || r !== -1)
|
2144 |
+
throw new SyntaxError("Unexpected end of input");
|
2145 |
+
const n = s.slice(t, i);
|
2146 |
+
if (e.has(n))
|
2147 |
+
throw new SyntaxError(`The "${n}" subprotocol is duplicated`);
|
2148 |
+
return e.add(n), e;
|
2149 |
+
}
|
2150 |
+
var Rs = { parse: Ps };
|
2151 |
+
const Us = S, ie = S, { createHash: Bs } = S, qe = nt, N = oe, $s = Rs, Ms = ft, { GUID: Is, kWebSocket: Ds } = U, Ws = /^[+/0-9A-Za-z]{22}==$/, Ke = 0, Xe = 1, pt = 2;
|
2152 |
+
class As extends Us {
|
2153 |
+
/**
|
2154 |
+
* Create a `WebSocketServer` instance.
|
2155 |
+
*
|
2156 |
+
* @param {Object} options Configuration options
|
2157 |
+
* @param {Number} [options.backlog=511] The maximum length of the queue of
|
2158 |
+
* pending connections
|
2159 |
+
* @param {Boolean} [options.clientTracking=true] Specifies whether or not to
|
2160 |
+
* track clients
|
2161 |
+
* @param {Function} [options.handleProtocols] A hook to handle protocols
|
2162 |
+
* @param {String} [options.host] The hostname where to bind the server
|
2163 |
+
* @param {Number} [options.maxPayload=104857600] The maximum allowed message
|
2164 |
+
* size
|
2165 |
+
* @param {Boolean} [options.noServer=false] Enable no server mode
|
2166 |
+
* @param {String} [options.path] Accept only connections matching this path
|
2167 |
+
* @param {(Boolean|Object)} [options.perMessageDeflate=false] Enable/disable
|
2168 |
+
* permessage-deflate
|
2169 |
+
* @param {Number} [options.port] The port where to bind the server
|
2170 |
+
* @param {(http.Server|https.Server)} [options.server] A pre-created HTTP/S
|
2171 |
+
* server to use
|
2172 |
+
* @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
|
2173 |
+
* not to skip UTF-8 validation for text and close messages
|
2174 |
+
* @param {Function} [options.verifyClient] A hook to reject connections
|
2175 |
+
* @param {Function} [options.WebSocket=WebSocket] Specifies the `WebSocket`
|
2176 |
+
* class to use. It must be the `WebSocket` class or class that extends it
|
2177 |
+
* @param {Function} [callback] A listener for the `listening` event
|
2178 |
+
*/
|
2179 |
+
constructor(e, t) {
|
2180 |
+
if (super(), e = {
|
2181 |
+
maxPayload: 100 * 1024 * 1024,
|
2182 |
+
skipUTF8Validation: !1,
|
2183 |
+
perMessageDeflate: !1,
|
2184 |
+
handleProtocols: null,
|
2185 |
+
clientTracking: !0,
|
2186 |
+
verifyClient: null,
|
2187 |
+
noServer: !1,
|
2188 |
+
backlog: null,
|
2189 |
+
// use default (511 as implemented in net.js)
|
2190 |
+
server: null,
|
2191 |
+
host: null,
|
2192 |
+
path: null,
|
2193 |
+
port: null,
|
2194 |
+
WebSocket: Ms,
|
2195 |
+
...e
|
2196 |
+
}, e.port == null && !e.server && !e.noServer || e.port != null && (e.server || e.noServer) || e.server && e.noServer)
|
2197 |
+
throw new TypeError(
|
2198 |
+
'One and only one of the "port", "server", or "noServer" options must be specified'
|
2199 |
+
);
|
2200 |
+
if (e.port != null ? (this._server = ie.createServer((r, i) => {
|
2201 |
+
const n = ie.STATUS_CODES[426];
|
2202 |
+
i.writeHead(426, {
|
2203 |
+
"Content-Length": n.length,
|
2204 |
+
"Content-Type": "text/plain"
|
2205 |
+
}), i.end(n);
|
2206 |
+
}), this._server.listen(
|
2207 |
+
e.port,
|
2208 |
+
e.host,
|
2209 |
+
e.backlog,
|
2210 |
+
t
|
2211 |
+
)) : e.server && (this._server = e.server), this._server) {
|
2212 |
+
const r = this.emit.bind(this, "connection");
|
2213 |
+
this._removeListeners = js(this._server, {
|
2214 |
+
listening: this.emit.bind(this, "listening"),
|
2215 |
+
error: this.emit.bind(this, "error"),
|
2216 |
+
upgrade: (i, n, o) => {
|
2217 |
+
this.handleUpgrade(i, n, o, r);
|
2218 |
+
}
|
2219 |
+
});
|
2220 |
+
}
|
2221 |
+
e.perMessageDeflate === !0 && (e.perMessageDeflate = {}), e.clientTracking && (this.clients = /* @__PURE__ */ new Set(), this._shouldEmitClose = !1), this.options = e, this._state = Ke;
|
2222 |
+
}
|
2223 |
+
/**
|
2224 |
+
* Returns the bound address, the address family name, and port of the server
|
2225 |
+
* as reported by the operating system if listening on an IP socket.
|
2226 |
+
* If the server is listening on a pipe or UNIX domain socket, the name is
|
2227 |
+
* returned as a string.
|
2228 |
+
*
|
2229 |
+
* @return {(Object|String|null)} The address of the server
|
2230 |
+
* @public
|
2231 |
+
*/
|
2232 |
+
address() {
|
2233 |
+
if (this.options.noServer)
|
2234 |
+
throw new Error('The server is operating in "noServer" mode');
|
2235 |
+
return this._server ? this._server.address() : null;
|
2236 |
+
}
|
2237 |
+
/**
|
2238 |
+
* Stop the server from accepting new connections and emit the `'close'` event
|
2239 |
+
* when all existing connections are closed.
|
2240 |
+
*
|
2241 |
+
* @param {Function} [cb] A one-time listener for the `'close'` event
|
2242 |
+
* @public
|
2243 |
+
*/
|
2244 |
+
close(e) {
|
2245 |
+
if (this._state === pt) {
|
2246 |
+
e && this.once("close", () => {
|
2247 |
+
e(new Error("The server is not running"));
|
2248 |
+
}), process.nextTick(G, this);
|
2249 |
+
return;
|
2250 |
+
}
|
2251 |
+
if (e && this.once("close", e), this._state !== Xe)
|
2252 |
+
if (this._state = Xe, this.options.noServer || this.options.server)
|
2253 |
+
this._server && (this._removeListeners(), this._removeListeners = this._server = null), this.clients ? this.clients.size ? this._shouldEmitClose = !0 : process.nextTick(G, this) : process.nextTick(G, this);
|
2254 |
+
else {
|
2255 |
+
const t = this._server;
|
2256 |
+
this._removeListeners(), this._removeListeners = this._server = null, t.close(() => {
|
2257 |
+
G(this);
|
2258 |
+
});
|
2259 |
+
}
|
2260 |
+
}
|
2261 |
+
/**
|
2262 |
+
* See if a given request should be handled by this server instance.
|
2263 |
+
*
|
2264 |
+
* @param {http.IncomingMessage} req Request object to inspect
|
2265 |
+
* @return {Boolean} `true` if the request is valid, else `false`
|
2266 |
+
* @public
|
2267 |
+
*/
|
2268 |
+
shouldHandle(e) {
|
2269 |
+
if (this.options.path) {
|
2270 |
+
const t = e.url.indexOf("?");
|
2271 |
+
if ((t !== -1 ? e.url.slice(0, t) : e.url) !== this.options.path)
|
2272 |
+
return !1;
|
2273 |
+
}
|
2274 |
+
return !0;
|
2275 |
+
}
|
2276 |
+
/**
|
2277 |
+
* Handle a HTTP Upgrade request.
|
2278 |
+
*
|
2279 |
+
* @param {http.IncomingMessage} req The request object
|
2280 |
+
* @param {(net.Socket|tls.Socket)} socket The network socket between the
|
2281 |
+
* server and client
|
2282 |
+
* @param {Buffer} head The first packet of the upgraded stream
|
2283 |
+
* @param {Function} cb Callback
|
2284 |
+
* @public
|
2285 |
+
*/
|
2286 |
+
handleUpgrade(e, t, r, i) {
|
2287 |
+
t.on("error", Ze);
|
2288 |
+
const n = e.headers["sec-websocket-key"], o = +e.headers["sec-websocket-version"];
|
2289 |
+
if (e.method !== "GET") {
|
2290 |
+
R(this, e, t, 405, "Invalid HTTP method");
|
2291 |
+
return;
|
2292 |
+
}
|
2293 |
+
if (e.headers.upgrade.toLowerCase() !== "websocket") {
|
2294 |
+
R(this, e, t, 400, "Invalid Upgrade header");
|
2295 |
+
return;
|
2296 |
+
}
|
2297 |
+
if (!n || !Ws.test(n)) {
|
2298 |
+
R(this, e, t, 400, "Missing or invalid Sec-WebSocket-Key header");
|
2299 |
+
return;
|
2300 |
+
}
|
2301 |
+
if (o !== 8 && o !== 13) {
|
2302 |
+
R(this, e, t, 400, "Missing or invalid Sec-WebSocket-Version header");
|
2303 |
+
return;
|
2304 |
+
}
|
2305 |
+
if (!this.shouldHandle(e)) {
|
2306 |
+
H(t, 400);
|
2307 |
+
return;
|
2308 |
+
}
|
2309 |
+
const l = e.headers["sec-websocket-protocol"];
|
2310 |
+
let f = /* @__PURE__ */ new Set();
|
2311 |
+
if (l !== void 0)
|
2312 |
+
try {
|
2313 |
+
f = $s.parse(l);
|
2314 |
+
} catch {
|
2315 |
+
R(this, e, t, 400, "Invalid Sec-WebSocket-Protocol header");
|
2316 |
+
return;
|
2317 |
+
}
|
2318 |
+
const a = e.headers["sec-websocket-extensions"], c = {};
|
2319 |
+
if (this.options.perMessageDeflate && a !== void 0) {
|
2320 |
+
const h = new N(
|
2321 |
+
this.options.perMessageDeflate,
|
2322 |
+
!0,
|
2323 |
+
this.options.maxPayload
|
2324 |
+
);
|
2325 |
+
try {
|
2326 |
+
const p = qe.parse(a);
|
2327 |
+
p[N.extensionName] && (h.accept(p[N.extensionName]), c[N.extensionName] = h);
|
2328 |
+
} catch {
|
2329 |
+
R(this, e, t, 400, "Invalid or unacceptable Sec-WebSocket-Extensions header");
|
2330 |
+
return;
|
2331 |
+
}
|
2332 |
+
}
|
2333 |
+
if (this.options.verifyClient) {
|
2334 |
+
const h = {
|
2335 |
+
origin: e.headers[`${o === 8 ? "sec-websocket-origin" : "origin"}`],
|
2336 |
+
secure: !!(e.socket.authorized || e.socket.encrypted),
|
2337 |
+
req: e
|
2338 |
+
};
|
2339 |
+
if (this.options.verifyClient.length === 2) {
|
2340 |
+
this.options.verifyClient(h, (p, v, _, u) => {
|
2341 |
+
if (!p)
|
2342 |
+
return H(t, v || 401, _, u);
|
2343 |
+
this.completeUpgrade(
|
2344 |
+
c,
|
2345 |
+
n,
|
2346 |
+
f,
|
2347 |
+
e,
|
2348 |
+
t,
|
2349 |
+
r,
|
2350 |
+
i
|
2351 |
+
);
|
2352 |
+
});
|
2353 |
+
return;
|
2354 |
+
}
|
2355 |
+
if (!this.options.verifyClient(h))
|
2356 |
+
return H(t, 401);
|
2357 |
+
}
|
2358 |
+
this.completeUpgrade(c, n, f, e, t, r, i);
|
2359 |
+
}
|
2360 |
+
/**
|
2361 |
+
* Upgrade the connection to WebSocket.
|
2362 |
+
*
|
2363 |
+
* @param {Object} extensions The accepted extensions
|
2364 |
+
* @param {String} key The value of the `Sec-WebSocket-Key` header
|
2365 |
+
* @param {Set} protocols The subprotocols
|
2366 |
+
* @param {http.IncomingMessage} req The request object
|
2367 |
+
* @param {(net.Socket|tls.Socket)} socket The network socket between the
|
2368 |
+
* server and client
|
2369 |
+
* @param {Buffer} head The first packet of the upgraded stream
|
2370 |
+
* @param {Function} cb Callback
|
2371 |
+
* @throws {Error} If called more than once with the same socket
|
2372 |
+
* @private
|
2373 |
+
*/
|
2374 |
+
completeUpgrade(e, t, r, i, n, o, l) {
|
2375 |
+
if (!n.readable || !n.writable)
|
2376 |
+
return n.destroy();
|
2377 |
+
if (n[Ds])
|
2378 |
+
throw new Error(
|
2379 |
+
"server.handleUpgrade() was called more than once with the same socket, possibly due to a misconfiguration"
|
2380 |
+
);
|
2381 |
+
if (this._state > Ke)
|
2382 |
+
return H(n, 503);
|
2383 |
+
const a = [
|
2384 |
+
"HTTP/1.1 101 Switching Protocols",
|
2385 |
+
"Upgrade: websocket",
|
2386 |
+
"Connection: Upgrade",
|
2387 |
+
`Sec-WebSocket-Accept: ${Bs("sha1").update(t + Is).digest("base64")}`
|
2388 |
+
], c = new this.options.WebSocket(null);
|
2389 |
+
if (r.size) {
|
2390 |
+
const h = this.options.handleProtocols ? this.options.handleProtocols(r, i) : r.values().next().value;
|
2391 |
+
h && (a.push(`Sec-WebSocket-Protocol: ${h}`), c._protocol = h);
|
2392 |
+
}
|
2393 |
+
if (e[N.extensionName]) {
|
2394 |
+
const h = e[N.extensionName].params, p = qe.format({
|
2395 |
+
[N.extensionName]: [h]
|
2396 |
+
});
|
2397 |
+
a.push(`Sec-WebSocket-Extensions: ${p}`), c._extensions = e;
|
2398 |
+
}
|
2399 |
+
this.emit("headers", a, i), n.write(a.concat(`\r
|
2400 |
+
`).join(`\r
|
2401 |
+
`)), n.removeListener("error", Ze), c.setSocket(n, o, {
|
2402 |
+
maxPayload: this.options.maxPayload,
|
2403 |
+
skipUTF8Validation: this.options.skipUTF8Validation
|
2404 |
+
}), this.clients && (this.clients.add(c), c.on("close", () => {
|
2405 |
+
this.clients.delete(c), this._shouldEmitClose && !this.clients.size && process.nextTick(G, this);
|
2406 |
+
})), l(c, i);
|
2407 |
+
}
|
2408 |
+
}
|
2409 |
+
var Fs = As;
|
2410 |
+
function js(s, e) {
|
2411 |
+
for (const t of Object.keys(e))
|
2412 |
+
s.on(t, e[t]);
|
2413 |
+
return function() {
|
2414 |
+
for (const r of Object.keys(e))
|
2415 |
+
s.removeListener(r, e[r]);
|
2416 |
+
};
|
2417 |
+
}
|
2418 |
+
function G(s) {
|
2419 |
+
s._state = pt, s.emit("close");
|
2420 |
+
}
|
2421 |
+
function Ze() {
|
2422 |
+
this.destroy();
|
2423 |
+
}
|
2424 |
+
function H(s, e, t, r) {
|
2425 |
+
t = t || ie.STATUS_CODES[e], r = {
|
2426 |
+
Connection: "close",
|
2427 |
+
"Content-Type": "text/html",
|
2428 |
+
"Content-Length": Buffer.byteLength(t),
|
2429 |
+
...r
|
2430 |
+
}, s.once("finish", s.destroy), s.end(
|
2431 |
+
`HTTP/1.1 ${e} ${ie.STATUS_CODES[e]}\r
|
2432 |
+
` + Object.keys(r).map((i) => `${i}: ${r[i]}`).join(`\r
|
2433 |
+
`) + `\r
|
2434 |
+
\r
|
2435 |
+
` + t
|
2436 |
+
);
|
2437 |
+
}
|
2438 |
+
function R(s, e, t, r, i) {
|
2439 |
+
if (s.listenerCount("wsClientError")) {
|
2440 |
+
const n = new Error(i);
|
2441 |
+
Error.captureStackTrace(n, R), s.emit("wsClientError", n, t, e);
|
2442 |
+
} else
|
2443 |
+
H(t, r, i);
|
2444 |
+
}
|
2445 |
+
const Zs = /* @__PURE__ */ z(Fs);
|
2446 |
+
export {
|
2447 |
+
qs as Receiver,
|
2448 |
+
Ks as Sender,
|
2449 |
+
Xs as WebSocket,
|
2450 |
+
Zs as WebSocketServer,
|
2451 |
+
Vs as createWebSocketStream,
|
2452 |
+
Xs as default
|
2453 |
+
};
|
gradio_dualvision/gradio_patches/templates/example/index.js
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Machine-generated bundle: compiled Svelte "example" component for the
// gradio_dualvision ImageSliderPlus thumbnail (left/right split preview).
// Do not edit by hand; regenerate from the Svelte source instead.
const {
  SvelteComponent: y,
  append: m,
  attr: n,
  detach: b,
  element: u,
  init: k,
  insert: p,
  noop: w,
  safe_not_equal: h,
  space: o,
  src_url_equal: v,
  toggle_class: f
} = window.__gradio__svelte__internal;
// create_fragment: builds <div class="wrap"><img/><img/><span/></div>, where the
// two images are the (value[0], value[1]) pair resolved against samples_dir and
// the span is the vertical divider; "table"/"gallery"/"selected" classes mirror
// the component props.
function q(a) {
  let e, l, _, g, i, c, s, d;
  return {
    c() {
      e = u("div"), l = u("img"), g = o(), i = u("img"), s = o(), d = u("span"), v(l.src, _ = /*samples_dir*/
      a[1] + /*value*/
      a[0][0]) || n(l, "src", _), n(l, "class", "svelte-l4wpk0"), v(i.src, c = /*samples_dir*/
      a[1] + /*value*/
      a[0][1]) || n(i, "src", c), n(i, "class", "svelte-l4wpk0"), n(d, "class", "svelte-l4wpk0"), n(e, "class", "wrap svelte-l4wpk0"), f(
        e,
        "table",
        /*type*/
        a[2] === "table"
      ), f(
        e,
        "gallery",
        /*type*/
        a[2] === "gallery"
      ), f(
        e,
        "selected",
        /*selected*/
        a[3]
      );
    },
    m(t, r) {
      p(t, e, r), m(e, l), m(e, g), m(e, i), m(e, s), m(e, d);
    },
    p(t, [r]) {
      r & /*samples_dir, value*/
      3 && !v(l.src, _ = /*samples_dir*/
      t[1] + /*value*/
      t[0][0]) && n(l, "src", _), r & /*samples_dir, value*/
      3 && !v(i.src, c = /*samples_dir*/
      t[1] + /*value*/
      t[0][1]) && n(i, "src", c), r & /*type*/
      4 && f(
        e,
        "table",
        /*type*/
        t[2] === "table"
      ), r & /*type*/
      4 && f(
        e,
        "gallery",
        /*type*/
        t[2] === "gallery"
      ), r & /*selected*/
      8 && f(
        e,
        "selected",
        /*selected*/
        t[3]
      );
    },
    i: w,
    o: w,
    d(t) {
      t && b(e);
    }
  };
}
// instance: declares the component props (value, samples_dir, type,
// selected=false) and wires prop updates into the reactive context slots 0-3.
function I(a, e, l) {
  let { value: _ } = e, { samples_dir: g } = e, { type: i } = e, { selected: c = !1 } = e;
  return a.$$set = (s) => {
    "value" in s && l(0, _ = s.value), "samples_dir" in s && l(1, g = s.samples_dir), "type" in s && l(2, i = s.type), "selected" in s && l(3, c = s.selected);
  }, [_, g, i, c];
}
// Default export: the compiled component class.
class C extends y {
  constructor(e) {
    super(), k(this, e, I, q, h, {
      value: 0,
      samples_dir: 1,
      type: 2,
      selected: 3
    });
  }
}
export {
  C as default
};
|
gradio_dualvision/gradio_patches/templates/example/style.css
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.wrap.svelte-l4wpk0.svelte-l4wpk0{position:relative;height:var(--size-64);width:var(--size-40);overflow:hidden;border-radius:var(--radius-lg)}img.svelte-l4wpk0.svelte-l4wpk0{height:var(--size-64);width:var(--size-40);position:absolute;object-fit:cover}.wrap.selected.svelte-l4wpk0.svelte-l4wpk0{border-color:var(--color-accent)}.wrap.svelte-l4wpk0 img.svelte-l4wpk0:first-child{clip-path:inset(0 50% 0 0%)}.wrap.svelte-l4wpk0 img.svelte-l4wpk0:nth-of-type(2){clip-path:inset(0 0 0 50%)}span.svelte-l4wpk0.svelte-l4wpk0{position:absolute;top:0;left:calc(50% - .75px);height:var(--size-64);width:1.5px;background:var(--border-color-primary)}.table.svelte-l4wpk0.svelte-l4wpk0{margin:0 auto;border:2px solid var(--border-color-primary);border-radius:var(--radius-lg)}.gallery.svelte-l4wpk0.svelte-l4wpk0{border:2px solid var(--border-color-primary);object-fit:cover}
|
gradio_dualvision/version.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
2 |
+
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
|
3 |
+
# See https://creativecommons.org/licenses/by-sa/4.0/ for details.
|
4 |
+
# --------------------------------------------------------------------------
|
5 |
+
# DualVision is a Gradio template app for image processing. It was developed
|
6 |
+
# to support the Marigold project. If you find this code useful, we kindly
|
7 |
+
# ask you to cite our most relevant papers.
|
8 |
+
# More information about Marigold:
|
9 |
+
# https://marigoldmonodepth.github.io
|
10 |
+
# https://marigoldcomputervision.github.io
|
11 |
+
# Efficient inference pipelines are now part of diffusers:
|
12 |
+
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
|
13 |
+
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
|
14 |
+
# Examples of trained models and live demos:
|
15 |
+
# https://huggingface.co/prs-eth
|
16 |
+
# Related projects:
|
17 |
+
# https://marigolddepthcompletion.github.io/
|
18 |
+
# https://rollingdepth.github.io/
|
19 |
+
# Citation (BibTeX):
|
20 |
+
# https://github.com/prs-eth/Marigold#-citation
|
21 |
+
# https://github.com/prs-eth/Marigold-DC#-citation
|
22 |
+
# https://github.com/prs-eth/rollingdepth#-citation
|
23 |
+
# --------------------------------------------------------------------------
|
24 |
+
|
25 |
+
__version__ = "0.1.0"
|
model/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .hyper import Hypernetwork
|
2 |
+
from .thera import build_thera
|
model/convnext.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import flax.linen as nn
|
2 |
+
from jaxtyping import Array, ArrayLike
|
3 |
+
|
4 |
+
|
5 |
+
class ConvNeXtBlock(nn.Module):
    """ConvNext block. See Fig.4 in "A ConvNet for the 2020s" by Liu et al.

    Residual block: KxK conv -> LayerNorm -> 1x1 expansion (4x) -> GELU ->
    1x1 projection back to `n_dims`, added to the input.

    https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.pdf
    """
    n_dims: int = 64  # number of input/output channels (block width)
    kernel_size: int = 3  # 7 in the paper's version
    group_features: bool = False  # True makes the first conv depthwise (one group per channel)

    def setup(self) -> None:
        # Submodules are registered in setup() so the parameter tree layout
        # stays stable for checkpoint loading.
        self.residual = nn.Sequential([
            nn.Conv(self.n_dims, kernel_size=(self.kernel_size, self.kernel_size), use_bias=False,
                    feature_group_count=self.n_dims if self.group_features else 1),
            nn.LayerNorm(),
            nn.Conv(4 * self.n_dims, kernel_size=(1, 1)),
            nn.gelu,
            nn.Conv(self.n_dims, kernel_size=(1, 1)),
        ])

    def __call__(self, x: ArrayLike) -> Array:
        # Identity shortcut; assumes x already has n_dims channels -- TODO confirm at call sites.
        return x + self.residual(x)
26 |
+
|
27 |
+
|
28 |
+
class Projection(nn.Module):
    """LayerNorm followed by a 1x1 conv that changes the channel width to `n_dims`."""
    n_dims: int  # target number of output channels

    @nn.compact
    def __call__(self, x: ArrayLike) -> Array:
        x = nn.LayerNorm()(x)
        x = nn.Conv(self.n_dims, (1, 1))(x)
        return x
|
36 |
+
|
37 |
+
|
38 |
+
class ConvNeXt(nn.Module):
    """Sequence of ConvNeXtBlocks with Projection layers inserted at width changes.

    `block_defs` is a list of positional-argument tuples for ConvNeXtBlock,
    e.g. (n_dims, kernel_size, group_features); only index 0 (n_dims) is
    inspected here to decide where a channel Projection is needed.
    """
    block_defs: list[tuple]

    def setup(self) -> None:
        layers = []
        current_size = self.block_defs[0][0]
        for block_def in self.block_defs:
            # Insert a projection whenever the channel width changes between blocks.
            if block_def[0] != current_size:
                layers.append(Projection(block_def[0]))
            layers.append(ConvNeXtBlock(*block_def))
            current_size = block_def[0]
        self.layers = layers

    def __call__(self, x: ArrayLike, _: bool) -> Array:
        # Second argument is an ignored training flag, kept so this module is
        # call-compatible with the other encoders used by the hypernetwork.
        for layer in self.layers:
            x = layer(x)
        return x
|
55 |
+
|
model/edsr.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from https://github.com/isaaccorley/jax-enhance
|
2 |
+
|
3 |
+
from functools import partial
|
4 |
+
from typing import Any, Sequence, Callable
|
5 |
+
|
6 |
+
import jax.numpy as jnp
|
7 |
+
import flax.linen as nn
|
8 |
+
from flax.core.frozen_dict import freeze
|
9 |
+
import einops
|
10 |
+
|
11 |
+
|
12 |
+
class PixelShuffle(nn.Module):
    """Depth-to-space upsampling (NHWC): rearranges (c * h2 * w2) channels into
    an (h2 x w2) spatial upscaling, analogous to torch.nn.PixelShuffle."""
    scale_factor: int  # spatial upscaling factor per axis

    def setup(self):
        # einops.rearrange is bound once with the fixed pattern; the call is a
        # pure reshape/transpose and holds no parameters.
        self.layer = partial(
            einops.rearrange,
            pattern="b h w (c h2 w2) -> b (h h2) (w w2) c",
            h2=self.scale_factor,
            w2=self.scale_factor
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.layer(x)
|
25 |
+
|
26 |
+
|
27 |
+
class ResidualBlock(nn.Module):
    """conv -> activation -> conv with an identity skip connection (EDSR style)."""
    channels: int
    kernel_size: Sequence[int]
    # NOTE(review): `res_scale` is accepted but never applied in __call__ (the
    # residual is added unscaled). Changing this would invalidate pretrained
    # checkpoints -- confirm intent before "fixing".
    res_scale: float
    activation: Callable
    dtype: Any = jnp.float32

    def setup(self):
        self.body = nn.Sequential([
            nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
            self.activation,
            nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
        ])

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return x + self.body(x)
|
43 |
+
|
44 |
+
|
45 |
+
class UpsampleBlock(nn.Module):
    """`num_upsamples` x2-upsampling stages: conv to 4*channels, then PixelShuffle(2).

    Total spatial upscaling is 2 ** num_upsamples.
    """
    num_upsamples: int
    channels: int
    kernel_size: Sequence[int]
    dtype: Any = jnp.float32

    def setup(self):
        layers = []
        for _ in range(self.num_upsamples):
            layers.extend([
                # 2 ** 2 = 4x channels feed one PixelShuffle step with factor 2
                nn.Conv(features=self.channels * 2 ** 2, kernel_size=self.kernel_size, dtype=self.dtype),
                PixelShuffle(scale_factor=2),
            ])
        self.layers = layers

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for layer in self.layers:
            x = layer(x)
        return x
|
64 |
+
|
65 |
+
|
66 |
+
class EDSR(nn.Module):
    """Enhanced Deep Residual Networks for Single Image Super-Resolution https://arxiv.org/pdf/1707.02921v1.pdf

    Trunk-only variant: head conv + residual body WITHOUT the mean-shift and
    upsampling tail, so it returns a `num_feats`-channel feature map at input
    resolution instead of an RGB image.
    """
    scale_factor: int  # NOTE(review): unused in this no-upsampling variant
    channels: int = 3  # NOTE(review): unused; input channels are inferred by nn.Conv
    num_blocks: int = 32
    num_feats: int = 256
    dtype: Any = jnp.float32

    def setup(self):
        # pre res blocks layer
        self.head = nn.Sequential([nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype)])

        # res blocks
        res_blocks = [
            ResidualBlock(channels=self.num_feats, kernel_size=(3, 3), res_scale=0.1, activation=nn.relu, dtype=self.dtype)
            for i in range(self.num_blocks)
        ]
        res_blocks.append(nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype))
        self.body = nn.Sequential(res_blocks)

    def __call__(self, x: jnp.ndarray, _=None) -> jnp.ndarray:
        # Second argument is an ignored training flag kept for encoder
        # interface compatibility.
        x = self.head(x)
        # global residual connection over the whole body
        x = x + self.body(x)
        return x
|
90 |
+
|
91 |
+
|
92 |
+
def convert_edsr_checkpoint(torch_dict, no_upsampling=True):
    """Convert a flat PyTorch EDSR state dict into a nested flax parameter tree.

    Args:
        torch_dict: mapping from dot-separated PyTorch parameter names
            (e.g. ``"body.0.body.0.weight"``) to tensors.
        no_upsampling: drop the upsampling ``tail`` branch (used when EDSR
            serves as a feature encoder only).

    Returns:
        A frozen nested dict with flax conventions: ``weight`` renamed to
        ``kernel`` and transposed OIHW -> HWIO, numeric submodule names
        prefixed with ``layers_``; ``add_mean``/``sub_mean`` removed.

    Raises:
        KeyError: if ``tail``, ``add_mean`` or ``sub_mean`` are absent.
    """
    def convert(in_dict):
        # first path segment of every key, and keys that are already leaves
        top_keys = set(k.split('.')[0] for k in in_dict.keys())
        leaves = set(k for k in in_dict.keys() if '.' not in k)

        # convert leaves
        out_dict = {}
        for l in leaves:
            if l == 'weight':
                # torch conv kernels are OIHW; flax expects HWIO
                out_dict['kernel'] = jnp.asarray(in_dict[l]).transpose((2, 3, 1, 0))
            elif l == 'bias':
                out_dict[l] = jnp.asarray(in_dict[l])
            else:
                out_dict[l] = in_dict[l]

        # recurse into sub-modules
        for top_key in top_keys.difference(leaves):
            new_top_key = 'layers_' + top_key if top_key.isdigit() else top_key
            # Match on "<top_key>." rather than bare startswith(top_key):
            # otherwise segment "3" would also swallow the children of
            # "30"/"31"/"32" and produce corrupted nested keys.
            prefix = top_key + '.'
            out_dict[new_top_key] = convert(
                {k[len(prefix):]: v for k, v in in_dict.items() if k.startswith(prefix)})
        return out_dict

    converted = convert(torch_dict)

    # remove unwanted keys
    if no_upsampling:
        del converted['tail']

    for k in ('add_mean', 'sub_mean'):
        del converted[k]

    return freeze(converted)
|
model/hyper.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import flax.linen as nn
|
6 |
+
from jaxtyping import Array, ArrayLike, PyTreeDef
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from utils import interpolate_grid
|
10 |
+
|
11 |
+
|
12 |
+
class Hypernetwork(nn.Module):
    """Predicts per-pixel parameter sets for a local neural field.

    An encoder + refinement backbone produces a feature grid; a single 1x1
    conv maps sampled features to a flat parameter vector, which is then
    split/reshaped and re-assembled into the field's parameter pytree.
    """
    encoder: nn.Module  # backbone called as encoder(source, training)
    refine: nn.Module  # refinement head called as refine(features, training)
    output_params_shape: list[tuple]  # e.g. [(16,), (32, 32), ...]
    tree_def: PyTreeDef  # used to reconstruct the parameter sets

    def setup(self):
        # one layer 1x1 conv to calculate field params, as in SIREN paper
        output_size = sum(math.prod(s) for s in self.output_params_shape)
        self.out_conv = nn.Conv(output_size, kernel_size=(1, 1), use_bias=True)

    def get_encoding(self, source: ArrayLike, training=False) -> Array:
        """Convenience method for whole-image evaluation"""
        return self.refine(self.encoder(source, training), training)

    def get_params_at_coords(self, encoding: ArrayLike, coords: ArrayLike) -> Array:
        """Sample the encoding at `coords` and convert it to field parameters."""
        # interpolate_grid is a project helper (see utils) -- presumably
        # bilinear sampling of the feature grid; verify there.
        encoding = interpolate_grid(coords, encoding)
        phi_params = self.out_conv(encoding)

        # reshape to output params shape
        phi_params = jnp.split(
            phi_params, np.cumsum([math.prod(s) for s in self.output_params_shape[:-1]]), axis=-1)
        phi_params = [jnp.reshape(p, p.shape[:-1] + s) for p, s in
                      zip(phi_params, self.output_params_shape)]

        return jax.tree_util.tree_unflatten(self.tree_def, phi_params)

    def __call__(self, source: ArrayLike, target_coords: ArrayLike, training=False) -> Array:
        encoding = self.get_encoding(source, training)
        return self.get_params_at_coords(encoding, target_coords)
|
model/init.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
from jaxtyping import Array
|
6 |
+
|
7 |
+
|
8 |
+
def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
    """Return an initializer sampling i.i.d. uniform values from [a, b)."""
    def init(key, shape, dtype=dtype) -> Array:
        # Delegate to jax.random.uniform with the closed-over bounds.
        sample = jax.random.uniform(key, shape, dtype=dtype, minval=a, maxval=b)
        return sample
    return init
|
12 |
+
|
13 |
+
|
14 |
+
def linear_up(scale: float) -> Callable:
    """Initializer for a (..., 2, n) matrix of 2-D vectors drawn uniformly
    from the disk of radius pi * scale (rows are x- and y-components)."""
    def init(key, shape, dtype=jnp.float32) -> Array:
        assert shape[-2] == 2
        key_radius, key_angle = jax.random.split(key, 2)
        # sqrt of a uniform variate gives an area-uniform radius distribution
        radius = jnp.pi * scale * (
            jax.random.uniform(key_radius, shape=(1, shape[-1])) ** .5)
        angle = 2 * jnp.pi * jax.random.uniform(key_angle, shape=(1, shape[-1]))
        components = [radius * jnp.cos(angle), radius * jnp.sin(angle)]
        return jnp.concatenate(components, axis=-2).astype(dtype)
    return init
|
model/rdn.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Residual Dense Network for Image Super-Resolution
|
2 |
+
# https://arxiv.org/abs/1802.08797
|
3 |
+
# modified from: https://github.com/thstkdgus35/EDSR-PyTorch
|
4 |
+
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import flax.linen as nn
|
7 |
+
|
8 |
+
|
9 |
+
class RDB_Conv(nn.Module):
    """Dense-connection conv unit: conv + ReLU, output concatenated onto the input channels."""
    growRate: int  # number of new feature channels appended per unit
    kSize: int = 3  # conv kernel size

    @nn.compact
    def __call__(self, x):
        out = nn.Sequential([
            nn.Conv(self.growRate, (self.kSize, self.kSize), padding=(self.kSize-1)//2),
            nn.activation.relu
        ])(x)
        # dense connectivity: channels grow by growRate each unit
        return jnp.concatenate((x, out), -1)
|
20 |
+
|
21 |
+
|
22 |
+
class RDB(nn.Module):
    """Residual Dense Block: stacked RDB_Conv units, 1x1 local feature fusion,
    and a local residual connection back to the block input."""
    growRate0: int  # base channel count (input and output width of the block)
    growRate: int  # channels added by each dense conv unit
    nConvLayers: int  # number of RDB_Conv units

    @nn.compact
    def __call__(self, x):
        res = x

        for c in range(self.nConvLayers):
            x = RDB_Conv(self.growRate)(x)

        # local feature fusion: 1x1 conv squeezes the densely grown channels
        # back to growRate0
        x = nn.Conv(self.growRate0, (1, 1))(x)

        return x + res
|
37 |
+
|
38 |
+
|
39 |
+
class RDN(nn.Module):
    """Residual Dense Network trunk (feature extractor; no upsampling tail here)."""
    G0: int = 64  # base feature width
    RDNkSize: int = 3  # conv kernel size
    RDNconfig: str = 'B'  # selects (D, C, G) = (#blocks, convs per block, growth rate)
    scale: int = 2  # NOTE(review): unused in this trunk-only variant
    n_colors: int = 3  # NOTE(review): unused; input channels inferred by nn.Conv

    @nn.compact
    def __call__(self, x, _=None):
        # Second argument is an ignored training flag kept for encoder
        # interface compatibility.
        D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }[self.RDNconfig]

        # Shallow feature extraction
        f_1 = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(x)
        x = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(f_1)

        # Residual dense blocks and dense feature fusion
        RDBs_out = []
        for i in range(D):
            x = RDB(self.G0, G, C)(x)
            RDBs_out.append(x)

        x = jnp.concatenate(RDBs_out, -1)

        # Global Feature Fusion
        x = nn.Sequential([
            nn.Conv(self.G0, (1, 1)),
            nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))
        ])(x)

        x = x + f_1
        return x
|
model/swin_ir.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable, Optional, Iterable
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
import flax.linen as nn
|
8 |
+
from jaxtyping import Array
|
9 |
+
|
10 |
+
|
11 |
+
def trunc_normal(mean=0., std=1., a=-2., b=2., dtype=jnp.float32) -> Callable:
    """Truncated normal initializer: N(mean, std) restricted to [a, b].

    Uses the inverse-CDF method, mirroring timm's trunc_normal_ helper.
    """
    def init(key, shape, dtype=dtype) -> Array:
        # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
        def norm_cdf(x):
            # standard normal cumulative distribution function
            return (1. + math.erf(x / math.sqrt(2.))) / 2.

        # map the truncation bounds through the CDF, sample uniformly in that
        # band, then invert back through the normal quantile function
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)
        u = jax.random.uniform(key, shape, dtype=dtype, minval=2 * lower - 1, maxval=2 * upper - 1)
        samples = jax.scipy.special.erfinv(u) * std * math.sqrt(2.) + mean
        # guard against numerical spill past the bounds
        return jnp.clip(samples, a, b)

    return init
|
27 |
+
|
28 |
+
|
29 |
+
def Dense(features, use_bias=True, kernel_init=trunc_normal(std=.02), bias_init=nn.initializers.zeros):
    """flax Dense with a timm-style truncated-normal kernel init by default."""
    layer = nn.Dense(features, use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init)
    return layer
|
31 |
+
|
32 |
+
|
33 |
+
def LayerNorm():
    """LayerNorm matching PyTorch's default epsilon (1e-5; flax defaults to 1e-6)."""
    norm = nn.LayerNorm(epsilon=1e-05)
    return norm
|
36 |
+
|
37 |
+
|
38 |
+
class Mlp(nn.Module):
    """Transformer feed-forward block: Dense -> activation -> Dropout -> Dense -> Dropout."""

    in_features: int  # fallback width when hidden/out features are not given
    hidden_features: Optional[int] = None  # defaults to in_features
    out_features: Optional[int] = None  # defaults to in_features
    act_layer: Callable = nn.gelu
    drop: float = 0.0  # dropout rate applied after each Dense

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.hidden_features or self.in_features)(x)
        x = self.act_layer(x)
        x = nn.Dropout(self.drop, deterministic=not training)(x)
        x = nn.Dense(self.out_features or self.in_features)(x)
        x = nn.Dropout(self.drop, deterministic=not training)(x)
        return x
|
54 |
+
|
55 |
+
|
56 |
+
def window_partition(x, window_size: int):
    """Split a batched feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C) feature map; H and W must be divisible by window_size
        window_size (int): side length of each window

    Returns:
        windows: (num_windows*B, window_size, window_size, C), ordered
        batch-major, then window rows, then window columns
    """
    batch, height, width, channels = x.shape
    # carve the spatial grid into (grid_h, ws, grid_w, ws) tiles
    tiled = x.reshape(
        (batch, height // window_size, window_size, width // window_size, window_size, channels))
    # bring the two grid axes together, then flatten them into the batch axis
    reordered = jnp.transpose(tiled, (0, 1, 3, 2, 4, 5))
    return reordered.reshape((-1, window_size, window_size, channels))
|
69 |
+
|
70 |
+
|
71 |
+
def window_reverse(windows, window_size: int, H: int, W: int):
    """Inverse of window_partition: reassemble windows into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    grid_h, grid_w = H // window_size, W // window_size
    # recover the batch size from the total window count
    batch = int(windows.shape[0] / (H * W / window_size / window_size))
    grid = windows.reshape((batch, grid_h, grid_w, window_size, window_size, -1))
    # interleave grid and window axes back into contiguous rows/columns
    interleaved = jnp.transpose(grid, (0, 1, 3, 2, 4, 5))
    return interleaved.reshape((batch, H, W, -1))
|
86 |
+
|
87 |
+
|
88 |
+
class DropPath(nn.Module):
    """Stochastic depth: randomly zeroes whole samples of the batch and rescales
    the survivors by 1/keep_prob; identity when not training.

    Implementation referred from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """

    dropout_prob: float = 0.1  # probability of dropping a sample's path
    # NOTE(review): `deterministic` is never consulted by __call__; the explicit
    # `training` argument controls behavior instead.
    deterministic: Optional[bool] = None

    @nn.compact
    def __call__(self, input, training):
        if not training:
            return input
        keep_prob = 1 - self.dropout_prob
        # one Bernoulli draw per sample, broadcast over all remaining dims
        shape = (input.shape[0],) + (1,) * (input.ndim - 1)
        rng = self.make_rng("dropout")
        random_tensor = keep_prob + jax.random.uniform(rng, shape)
        random_tensor = jnp.floor(random_tensor)  # 1 with prob keep_prob, else 0
        return jnp.divide(input, keep_prob) * random_tensor
|
106 |
+
|
107 |
+
|
108 |
+
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention with relative position bias,
    as used in Swin Transformer / SwinIR.

    Expects inputs already partitioned into windows and flattened:
    (num_windows * B, window_size[0] * window_size[1], C).
    """
    dim: int  # embedding dimension (split evenly across heads)
    window_size: Iterable[int]  # (Wh, Ww)
    num_heads: int
    qkv_bias: bool = True  # learn a bias in the qkv projection
    qk_scale: Optional[float] = None  # overrides the default head_dim ** -0.5
    att_drop: float = 0.0  # dropout on attention weights
    proj_drop: float = 0.0  # dropout after the output projection

    def make_rel_pos_index(self):
        # Precompute (with numpy, so it is a static constant, not traced) the
        # (Wh*Ww, Wh*Ww) table of indices into the relative-position bias
        # table for every ordered pair of positions within a window.
        h_indices = np.arange(0, self.window_size[0])
        w_indices = np.arange(0, self.window_size[1])
        indices = np.stack(np.meshgrid(w_indices, h_indices, indexing="ij"))
        flatten_indices = np.reshape(indices, (2, -1))
        relative_indices = flatten_indices[:, :, None] - flatten_indices[:, None, :]
        relative_indices = np.transpose(relative_indices, (1, 2, 0))
        # shift both offset axes to be non-negative, then flatten the 2-D
        # offset into a single table index
        relative_indices[:, :, 0] += self.window_size[0] - 1
        relative_indices[:, :, 1] += self.window_size[1] - 1
        relative_indices[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_pos_index = np.sum(relative_indices, -1)
        return relative_pos_index

    @nn.compact
    def __call__(self, inputs, mask, training):
        # learned bias table: one entry per (relative offset, head)
        rpbt = self.param(
            "relative_position_bias_table",
            trunc_normal(std=.02),
            (
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                self.num_heads,
            ),
        )

        #relative_pos_index = self.variable(
        #    "variables", "relative_position_index", self.get_rel_pos_index
        #)

        batch, n, channels = inputs.shape
        # joint projection to queries/keys/values -> (3, B, heads, N, C/heads)
        qkv = nn.Dense(self.dim * 3, use_bias=self.qkv_bias, name="qkv")(inputs)
        qkv = qkv.reshape(batch, n, 3, self.num_heads, channels // self.num_heads)
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        scale = self.qk_scale or (self.dim // self.num_heads) ** -0.5
        q = q * scale
        att = q @ jnp.swapaxes(k, -2, -1)

        # gather the bias for every position pair and add it to the logits
        rel_pos_bias = jnp.reshape(
            rpbt[np.reshape(self.make_rel_pos_index(), (-1))],
            (
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1,
            ),
        )
        rel_pos_bias = jnp.transpose(rel_pos_bias, (2, 0, 1))
        att += jnp.expand_dims(rel_pos_bias, 0)

        if mask is not None:
            # mask is per-window; broadcast it over batch and heads
            # (assumes mask shape (num_windows, N, N) -- TODO confirm at call sites)
            att = jnp.reshape(
                att, (batch // mask.shape[0], mask.shape[0], self.num_heads, n, n)
            )
            att = att + jnp.expand_dims(jnp.expand_dims(mask, 1), 0)
            att = jnp.reshape(att, (-1, self.num_heads, n, n))
            att = jax.nn.softmax(att)

        else:
            att = jax.nn.softmax(att)

        att = nn.Dropout(self.att_drop)(att, deterministic=not training)

        # weighted sum over values, merge heads, then output projection
        x = jnp.reshape(jnp.swapaxes(att @ v, 1, 2), (batch, n, channels))
        x = nn.Dense(self.dim, name="proj")(x)
        x = nn.Dropout(self.proj_drop)(x, deterministic=not training)
        return x
|
183 |
+
|
184 |
+
|
185 |
+
class SwinTransformerBlock(nn.Module):
    """One Swin block: (shifted-)window attention + MLP, each with a residual.

    `x` is a token sequence (B, H*W, C); `x_size` is the actual (H, W), which
    may differ from `input_resolution` at test time.
    """

    dim: int
    input_resolution: tuple[int]
    num_heads: int
    window_size: int = 7
    shift_size: int = 0          # 0 = W-MSA, window_size // 2 = SW-MSA
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    act_layer: Callable = nn.activation.gelu
    norm_layer: Callable = LayerNorm

    @staticmethod
    def make_att_mask(shift_size, window_size, height, width):
        """Attention mask for shifted windows; None when no shift is applied."""
        if shift_size > 0:
            # label each of the 9 spatial regions produced by the cyclic shift
            mask = jnp.zeros([1, height, width, 1])
            h_slices = (
                slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None),
            )
            w_slices = (
                slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None),
            )

            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask = mask.at[:, h, w, :].set(count)
                    count += 1

            mask_windows = window_partition(mask, window_size)
            mask_windows = jnp.reshape(mask_windows, (-1, window_size * window_size))
            att_mask = jnp.expand_dims(mask_windows, 1) - jnp.expand_dims(mask_windows, 2)
            # -100 suppresses attention between tokens from different regions
            att_mask = jnp.where(att_mask != 0.0, float(-100.0), att_mask)
            att_mask = jnp.where(att_mask == 0.0, float(0.0), att_mask)
        else:
            att_mask = None

        return att_mask

    @nn.compact
    def __call__(self, x, x_size, training):
        H, W = x_size
        B, L, C = x.shape

        # BUGFIX: flax linen Modules are frozen dataclasses, so the original
        # `self.shift_size = 0` / `self.window_size = ...` raised at call time.
        # Use local copies instead.
        window_size, shift_size = self.window_size, self.shift_size
        if min(self.input_resolution) <= window_size:
            # if window size is larger than input resolution, we don't partition windows
            shift_size = 0
            window_size = min(self.input_resolution)
        assert 0 <= shift_size < window_size, "shift_size must in 0-window_size"

        shortcut = x
        x = self.norm_layer()(x)
        x = x.reshape((B, H, W, C))

        # cyclic shift
        if shift_size > 0:
            shifted_x = jnp.roll(x, (-shift_size, -shift_size), axis=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.reshape((-1, window_size * window_size, C))  # nW*B, window_size*window_size, C

        attn_mask = self.make_att_mask(shift_size, window_size, *self.input_resolution)

        attn = WindowAttention(self.dim, (window_size, window_size), self.num_heads,
                               self.qkv_bias, self.qk_scale, self.attn_drop, self.drop)
        if self.input_resolution == x_size:
            attn_windows = attn(x_windows, attn_mask, training)  # nW*B, window_size*window_size, C
        else:
            # test time: input size differs from training resolution, so the
            # mask must be rebuilt for the actual (H, W)
            assert not training
            test_mask = self.make_att_mask(shift_size, window_size, *x_size)
            attn_windows = attn(x_windows, test_mask, training=False)

        # merge windows
        attn_windows = attn_windows.reshape((-1, window_size, window_size, C))
        shifted_x = window_reverse(attn_windows, window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if shift_size > 0:
            x = jnp.roll(shifted_x, (shift_size, shift_size), axis=(1, 2))
        else:
            x = shifted_x

        x = x.reshape((B, H * W, C))

        # FFN
        x = shortcut + DropPath(self.drop_path)(x, training)

        norm = self.norm_layer()(x)
        mlp = Mlp(in_features=self.dim, hidden_features=int(self.dim * self.mlp_ratio),
                  act_layer=self.act_layer, drop=self.drop)(norm, training)
        x = x + DropPath(self.drop_path)(mlp, training)

        return x
|
300 |
+
|
301 |
+
|
302 |
+
class PatchMerging(nn.Module):
    """Downsample a token map 2x spatially while doubling the channel dim.

    Input is a flattened (B, H*W, C) sequence; output is (B, H/2*W/2, 2*dim).
    """

    inp_res: Iterable[int]          # (H, W) of the incoming token map
    dim: int
    norm_layer: Callable = LayerNorm

    @nn.compact
    def __call__(self, inputs):
        batch, n, channels = inputs.shape
        height, width = self.inp_res[0], self.inp_res[1]
        x = jnp.reshape(inputs, (batch, height, width, channels))

        # gather the four pixels of each 2x2 neighborhood
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]

        # concat to 4C channels, then linearly reduce to 2*dim
        x = jnp.concatenate([x0, x1, x2, x3], axis=-1)
        x = jnp.reshape(x, (batch, -1, 4 * channels))
        x = self.norm_layer()(x)
        x = nn.Dense(2 * self.dim, use_bias=False)(x)
        return x
|
323 |
+
|
324 |
+
|
325 |
+
class BasicLayer(nn.Module):
    """A stack of `depth` Swin blocks, alternating W-MSA and SW-MSA,
    optionally followed by a downsampling module."""

    dim: int
    # NOTE(review): annotated int but callers pass an (H, W) tuple — confirm
    input_resolution: int
    depth: int
    num_heads: int
    window_size: int
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    # may be a scalar or a per-block list/tuple (see the isinstance check below)
    drop_path: float = 0.
    norm_layer: Callable = LayerNorm
    downsample: Optional[Callable] = None

    @nn.compact
    def __call__(self, x, x_size, training):
        for i in range(self.depth):
            x = SwinTransformerBlock(
                self.dim,
                self.input_resolution,
                self.num_heads,
                self.window_size,
                # even blocks: plain windows; odd blocks: shifted windows
                0 if (i % 2 == 0) else self.window_size // 2,
                self.mlp_ratio,
                self.qkv_bias,
                self.qk_scale,
                self.drop,
                self.attn_drop,
                self.drop_path[i] if isinstance(self.drop_path, (list, tuple)) else self.drop_path,
                norm_layer=self.norm_layer
            )(x, x_size, training)

        if self.downsample is not None:
            x = self.downsample(self.input_resolution, dim=self.dim, norm_layer=self.norm_layer)(x)

        return x
|
363 |
+
|
364 |
+
|
365 |
+
class RSTB(nn.Module):
    """Residual Swin Transformer Block: BasicLayer + conv, wrapped in a
    long residual connection (SwinIR architecture)."""

    dim: int
    # NOTE(review): annotated int but callers pass an (H, W) tuple — confirm
    input_resolution: int
    depth: int
    num_heads: int
    window_size: int
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.  # scalar or per-block list/tuple
    norm_layer: Callable = LayerNorm
    downsample: Optional[Callable] = None
    # BUGFIX: the original had trailing commas (`= 224,` / `= 4,`), which made
    # these defaults the tuples (224,) and (4,) instead of ints.
    img_size: int = 224
    patch_size: int = 4
    resi_connection: str = '1conv'

    @nn.compact
    def __call__(self, x, x_size, training):
        res = x
        x = BasicLayer(dim=self.dim,
                       input_resolution=self.input_resolution,
                       depth=self.depth,
                       num_heads=self.num_heads,
                       window_size=self.window_size,
                       mlp_ratio=self.mlp_ratio,
                       qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                       drop=self.drop, attn_drop=self.attn_drop,
                       drop_path=self.drop_path,
                       norm_layer=self.norm_layer,
                       downsample=self.downsample)(x, x_size, training)

        # back to spatial layout for the convolution
        x = PatchUnEmbed(embed_dim=self.dim)(x, x_size)

        # resi_connection == '1conv':
        x = nn.Conv(self.dim, (3, 3))(x)

        # flatten again and close the residual
        x = PatchEmbed()(x)

        return x + res
|
407 |
+
|
408 |
+
|
409 |
+
class PatchEmbed(nn.Module):
    """Flatten a (B, Ph, Pw, C) feature map into a (B, Ph*Pw, C) sequence."""

    # optional normalization factory applied after flattening
    norm_layer: Optional[Callable] = None

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1, x.shape[-1]))  # B Ph Pw C -> B Ph*Pw C
        if self.norm_layer is not None:
            x = self.norm_layer()(x)
        return x
|
418 |
+
|
419 |
+
|
420 |
+
class PatchUnEmbed(nn.Module):
    """Fold a (B, H*W, C) token sequence back into a (B, H, W, embed_dim) map."""

    embed_dim: int = 96

    @nn.compact
    def __call__(self, x, x_size):
        batch = x.shape[0]
        height, width = x_size
        return x.reshape((batch, height, width, self.embed_dim))
|
428 |
+
|
429 |
+
|
430 |
+
class SwinIR(nn.Module):
    r""" SwinIR JAX implementation
    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
    """

    img_size: int = 48
    patch_size: int = 1
    in_chans: int = 3
    embed_dim: int = 180
    depths: tuple = (6, 6, 6, 6, 6, 6)
    num_heads: tuple = (6, 6, 6, 6, 6, 6)
    window_size: int = 8
    mlp_ratio: float = 2.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop_rate: float = 0.
    attn_drop_rate: float = 0.
    drop_path_rate: float = 0.1
    norm_layer: Callable = LayerNorm
    ape: bool = False
    patch_norm: bool = True
    upscale: int = 2
    img_range: float = 1.
    num_feat: int = 64

    def pad(self, x):
        """Reflect-pad H and W up to the next multiple of `window_size`."""
        _, h, w, _ = x.shape
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = jnp.pad(x, ((0, 0), (0, mod_pad_h), (0, mod_pad_w), (0, 0)), 'reflect')
        return x

    @nn.compact
    def __call__(self, x, training):
        _, h_before, w_before, _ = x.shape
        x = self.pad(x)
        _, h, w, _ = x.shape
        # nominal (training-time) token-map resolution; the actual input may
        # be larger at test time — the blocks rebuild masks in that case
        patches_resolution = [self.img_size // self.patch_size] * 2
        num_patches = patches_resolution[0] * patches_resolution[1]

        # conv_first
        x = nn.Conv(self.embed_dim, (3, 3))(x)
        res = x

        # feature extraction
        x_size = (h, w)
        x = PatchEmbed(self.norm_layer if self.patch_norm else None)(x)

        if self.ape:
            # learnable absolute position embedding (sized for the nominal grid)
            absolute_pos_embed = \
                self.param('ape', trunc_normal(std=.02), (1, num_patches, self.embed_dim))
            x = x + absolute_pos_embed

        x = nn.Dropout(self.drop_rate, deterministic=not training)(x)

        # linearly increasing stochastic-depth rates over all blocks
        dpr = [x.item() for x in np.linspace(0, self.drop_path_rate, sum(self.depths))]
        for i_layer in range(len(self.depths)):
            x = RSTB(
                dim=self.embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=self.depths[i_layer],
                num_heads=self.num_heads[i_layer],
                window_size=self.window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                drop=self.drop_rate, attn_drop=self.attn_drop_rate,
                drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
                norm_layer=self.norm_layer,
                downsample=None,
                img_size=self.img_size,
                patch_size=self.patch_size)(x, x_size, training)

        x = self.norm_layer()(x)  # B L C
        x = PatchUnEmbed(self.embed_dim)(x, x_size)

        # conv_after_body
        x = nn.Conv(self.embed_dim, (3, 3))(x)
        x = x + res

        # conv_before_upsample
        x = nn.activation.leaky_relu(nn.Conv(self.num_feat, (3, 3))(x))

        # revert padding (`or None` keeps the full extent when no padding was added)
        x = x[:, :-(h - h_before) or None, :-(w - w_before) or None]
        return x
|
model/tail.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import flax.linen as nn
|
2 |
+
|
3 |
+
from .convnext import ConvNeXt
|
4 |
+
from .swin_ir import SwinIR
|
5 |
+
|
6 |
+
|
7 |
+
def build_tail(size: str):
    """Build one of the three tail variants ('air', 'plus', 'pro') from the paper."""
    if size == 'air':
        # identity tail; the second (unused) argument mirrors the tail interface
        return lambda x, _: x
    if size == 'plus':
        block_specs = (
            6 * [(64, 3, True)]
            + 7 * [(96, 3, True)]
            + 3 * [(128, 3, True)]
        )
        return ConvNeXt(block_specs)
    if size == 'pro':
        return SwinIR(depths=[7, 6], num_heads=[6, 6])
    raise NotImplementedError('size: ' + size)
|
18 |
+
|
model/thera.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import jax
|
4 |
+
from flax.core import unfreeze, freeze
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import flax.linen as nn
|
7 |
+
from jaxtyping import Array, ArrayLike, PyTree
|
8 |
+
|
9 |
+
from .edsr import EDSR
|
10 |
+
from .rdn import RDN
|
11 |
+
from .hyper import Hypernetwork
|
12 |
+
from .tail import build_tail
|
13 |
+
from .init import uniform_between, linear_up
|
14 |
+
from utils import make_grid, interpolate_grid, repeat_vmap
|
15 |
+
|
16 |
+
|
17 |
+
class Thermal(nn.Module):
    """Sine activation with a learnable per-feature phase, damped over time.

    The factor exp(-(w0_scale * norm)^2 * k * t) attenuates each component
    proportionally to its (scaled) frequency norm, diffusivity `k`, and time `t`.
    """

    w0_scale: float = 1.

    @nn.compact
    def __call__(self, x: ArrayLike, t, norm, k) -> Array:
        # learnable phase offset per feature, uniformly initialized (scale 0.5)
        phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
        return jnp.sin(self.w0_scale * x + phase) * jnp.exp(-(self.w0_scale * norm)**2 * k * t)
|
24 |
+
|
25 |
+
|
26 |
+
class TheraField(nn.Module):
    """A single Thera neural field: shared coordinate projection, thermal
    activation, then a linear map to the output space."""

    dim_hidden: int
    dim_out: int
    w0: float = 1.   # frequency scale of the thermal/sine activation
    c: float = 6.    # SIREN-style init constant for the second layer

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
        # coordinate projection according to shared components ("first layer")
        x = x @ components

        # thermal activations; `norm` carries each component's frequency magnitude
        norm = jnp.linalg.norm(components, axis=-2)
        x = Thermal(self.w0)(x, t, norm, k)

        # linear projection from hidden to output space ("second layer"),
        # initialized uniform in [-w_std, w_std] as in SIREN
        w_std = math.sqrt(self.c / self.dim_hidden) / self.w0
        dense_init_fn = uniform_between(-w_std, w_std)
        x = nn.Dense(self.dim_out, kernel_init=dense_init_fn, use_bias=False)(x)

        return x
|
47 |
+
|
48 |
+
|
49 |
+
class Thera:
    """Thera super-resolution model: a hypernetwork predicts per-pixel field
    parameters, and a grid of TheraField instances decodes them at arbitrary
    coordinates and scales. Not itself a flax Module; holds two sub-models."""

    def __init__(
            self,
            hidden_dim: int,
            out_dim: int,
            backbone: nn.Module,
            tail: nn.Module,
            k_init: float = None,
            components_init_scale: float = None
    ):
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        # single TheraField object whose `apply` method is used for all grid cells
        self.field = TheraField(hidden_dim, out_dim)

        # infer output size of the hypernetwork from a sample pass through the
        # field; the key doesn't matter as these params are only used for size
        # inference
        sample_params = self.field.init(jax.random.PRNGKey(0),
                                        jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]

        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)

    def init(self, key, sample_source) -> PyTree:
        """Initialize hypernetwork params plus the global 'k' and 'components'."""
        keys = jax.random.split(key, 2)
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))

        # globally shared diffusivity and projection components live alongside
        # the hypernetwork parameters
        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))

        return freeze(params)

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
        self,
        params: PyTree,
        encoding: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of HxW Thera fields,
        informed by `encoding`, at spatial and temporal coordinates
        `coords` and `t`, respectively.
        args:
            params: Field parameters, shape (B, H, W, N)
            encoding: Encoding tensor, shape (B, H, W, C)
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
            t: Temporal coordinates, shape (B, 1)
        returns:
            Decoded output; with `return_jac=True`, also the spatial Jacobian
            evaluated at t=0.
        """
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # create local coordinate systems: each target coordinate is expressed
        # relative to its nearest source-grid cell center, in source-pixel units
        source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = (coords - interp_coords)
        rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
        rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])

        # three maps over params, coords; one over t; don't map k and components
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, params['params']['k'],
                          params['params']['components'])

        if return_jac:
            # Jacobian w.r.t. coordinates (argnums=1), evaluated at t=0
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
                            params['params']['components'])
            return out, jac

        return out

    def apply(
        self,
        params: ArrayLike,
        source: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False,
        **kwargs
    ) -> Array:
        """
        Performs a forward pass through the Thera model: encode, then decode.
        """
        encoding = self.apply_encoder(params, source, **kwargs)
        out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
        return out
|
152 |
+
|
153 |
+
|
154 |
+
def build_thera(
    out_dim: int,
    backbone: str,
    size: str,
    k_init: float = None,
    components_init_scale: float = None
):
    """Convenience function for building the three Thera sizes described in the paper.

    `backbone` selects the feature extractor, `size` selects the tail and
    the field's hidden dimension.
    """
    # only the 'air' variant uses the small field
    hidden_dim = 32 if size == 'air' else 512

    if backbone == 'edsr-baseline':
        backbone_module = EDSR(None, num_blocks=16, num_feats=64)
    elif backbone == 'rdn':
        backbone_module = RDN()
    else:
        raise NotImplementedError(backbone)

    return Thera(hidden_dim, out_dim, backbone_module, build_tail(size),
                 k_init, components_init_scale)
|
requirements.txt
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
2 |
+
# torch-cpu is sufficient since we only use it for data loading
|
3 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
4 |
+
|
5 |
+
chex==0.1.7
|
6 |
+
ConfigArgParse==1.7
|
7 |
+
einops==0.6.1
|
8 |
+
flax==0.6.10
|
9 |
+
flaxmodels==0.1.3
|
10 |
+
jax==0.4.11
|
11 |
+
jaxlib==0.4.11+cuda11.cudnn86
|
12 |
+
jaxtyping==0.2.20
|
13 |
+
ml-dtypes==0.1.0
|
14 |
+
numpy==1.24.1
|
15 |
+
nvidia-cublas-cu11==11.11.3.6
|
16 |
+
nvidia-cuda-cupti-cu11==11.8.87
|
17 |
+
nvidia-cuda-nvcc-cu11==11.8.89
|
18 |
+
nvidia-cuda-runtime-cu11==11.8.89
|
19 |
+
nvidia-cudnn-cu11==8.9.2.26
|
20 |
+
nvidia-cufft-cu11==10.9.0.58
|
21 |
+
nvidia-cusolver-cu11==11.4.1.48
|
22 |
+
nvidia-cusparse-cu11==11.7.5.86
|
23 |
+
opt-einsum==3.3.0
|
24 |
+
optax==0.2.0
|
25 |
+
orbax-checkpoint==0.2.4
|
26 |
+
scipy==1.10.1
|
27 |
+
timm==0.9.6
|
28 |
+
torch==2.0.1+cpu
|
29 |
+
torchaudio==2.0.2+cpu
|
30 |
+
torchmetrics==1.2.0
|
31 |
+
torchvision==0.15.2+cpu
|
32 |
+
tqdm==4.65.0
|
33 |
+
transformers==4.46.3
|
34 |
+
Pillow==10.0.0
|
35 |
+
wandb
|
36 |
+
|
37 |
+
# gradio
|
38 |
+
gradio==4.44.1
|
39 |
+
gradio_imageslider==0.0.20
|
40 |
+
spaces
|
super_resolve.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from argparse import ArgumentParser, Namespace
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
import jax
|
7 |
+
from jax import jit
|
8 |
+
import jax.numpy as jnp
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from model import build_thera
|
13 |
+
from utils import make_grid, interpolate_grid
|
14 |
+
|
15 |
+
MEAN = jnp.array([.4488, .4371, .4040])
|
16 |
+
VAR = jnp.array([.25, .25, .25])
|
17 |
+
PATCH_SIZE = 256
|
18 |
+
|
19 |
+
|
20 |
+
def process_single(source, apply_encoder, apply_decoder, params, target_shape):
    """Super-resolve one (H, W, C) image to `target_shape`, decoding in tiles.

    The decoder predicts a residual that is added to a nearest-neighbor
    upsampling of the (unnormalized) source.
    """
    # thermal time from the scale factor: t = scale**-2
    # NOTE(review): uses only the height ratio — assumes uniform scaling in h/w
    t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
    source_up = interpolate_grid(coords_nearest, source[None])
    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]

    encoding = apply_encoder(params, source)
    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)

    # decode in PATCH_SIZE x PATCH_SIZE coordinate tiles to bound memory use
    for h_min in range(0, coords.shape[1], PATCH_SIZE):
        h_max = min(h_min + PATCH_SIZE, coords.shape[1])
        for w_min in range(0, coords.shape[2], PATCH_SIZE):
            # apply decoder with one patch of coordinates
            w_max = min(w_min + PATCH_SIZE, coords.shape[2])
            coords_patch = coords[:, h_min:h_max, w_min:w_max]
            out_patch = apply_decoder(params, encoding, coords_patch, t)
            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)

    # undo normalization, then add the upsampled source (residual prediction)
    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
    out += source_up
    return out
|
42 |
+
|
43 |
+
|
44 |
+
def process(source, model, params, target_shape, do_ensemble=True):
    """Super-resolve `source`, optionally averaging over a 4-rotation geo-ensemble.

    Returns a uint8 image of shape `target_shape` (+ channels).
    """
    encode = jit(model.apply_encoder)
    decode = jit(model.apply_decoder)

    n_rotations = 4 if do_ensemble else 1
    predictions = []
    for k in range(n_rotations):
        rotated = jnp.rot90(source, k=k, axes=(-3, -2))
        # odd rotations swap target height and width
        shape_k = target_shape if k % 2 == 0 else tuple(reversed(target_shape))
        pred = process_single(rotated, encode, decode, params, shape_k)
        # rotate back before averaging
        predictions.append(jnp.rot90(pred, k=k, axes=(-2, -3)))

    averaged = jnp.stack(predictions).mean(0).clip(0., 1.)
    return jnp.rint(averaged[0] * 255).astype(jnp.uint8)
|
57 |
+
|
58 |
+
|
59 |
+
def main(args: Namespace):
    """Run single-image super-resolution end to end: load, infer, save."""
    source = np.asarray(Image.open(args.in_file)) / 255.

    # resolve the target resolution from exactly one of --scale / --size
    if args.scale is not None:
        if args.size is not None:
            raise ValueError('Cannot specify both size and scale')
        target_shape = (
            round(source.shape[0] * args.scale),
            round(source.shape[1] * args.scale),
        )
    elif args.size is not None:
        target_shape = args.size
    else:
        raise ValueError('Must specify either size or scale')

    # checkpoint bundles the weights with the architecture identifiers
    with open(args.checkpoint, 'rb') as fh:
        check = pickle.load(fh)

    model = build_thera(3, check['backbone'], check['size'])
    out = process(source, model, check['model'], target_shape,
                  not args.no_ensemble)

    Image.fromarray(np.asarray(out)).save(args.out_file)
|
83 |
+
|
84 |
+
|
85 |
+
def parse_args() -> Namespace:
    """Build the command-line interface and parse sys.argv."""
    parser = ArgumentParser()
    for name, opts in (
        ('in_file', {}),
        ('out_file', {}),
        ('--scale', dict(type=float, help='Scale factor for super-resolution')),
        ('--size', dict(type=int, nargs=2,
                        help='Target size (h, w), mutually exclusive with --scale')),
        ('--checkpoint', dict(help='Path to checkpoint file')),
        ('--no-ensemble', dict(action='store_true', help='Disable geo-ensemble')),
    ):
        parser.add_argument(name, **opts)
    return parser.parse_args()
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == '__main__':
|
98 |
+
args = parse_args()
|
99 |
+
main(args)
|
utils.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import jax
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def repeat_vmap(fun, in_axes=(0,)):
    """Wrap `fun` in nested `jax.vmap`s, one per entry of `in_axes`.

    Args:
        fun: function to vectorize.
        in_axes: iterable of `in_axes` specs; the first entry becomes the
            innermost vmap, the last the outermost.
    Returns:
        The nested-vmapped function.
    """
    # immutable tuple default replaces the original mutable-default list `[0]`
    for axes in in_axes:
        fun = jax.vmap(fun, in_axes=axes)
    return fun
|
11 |
+
|
12 |
+
|
13 |
+
def make_grid(patch_size: int | tuple[int, int]):
|
14 |
+
if isinstance(patch_size, int):
|
15 |
+
patch_size = (patch_size, patch_size)
|
16 |
+
offset_h, offset_w = 1 / (2 * np.array(patch_size))
|
17 |
+
space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
|
18 |
+
space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
|
19 |
+
return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1) # [h, w]
|
20 |
+
|
21 |
+
|
22 |
+
def interpolate_grid(coords, grid, order=0):
    """Sample `grid` at normalized coordinates.

    args:
        coords: Tensor of shape (B, H, W, 2) with coordinates in [-0.5, 0.5]
        grid: Tensor of shape (B, H', W', C)
    returns:
        Tensor of shape (B, H, W, C) with interpolated values
    """
    # convert [-0.5, 0.5] -> [0, size]; pixel centers sit at
    # [-0.5 + 1/(2*size), ..., 0.5 - 1/(2*size)]
    grid_h, grid_w = grid.shape[-3], grid.shape[-2]
    pix = coords.transpose((0, 3, 1, 2))  # (B, 2, H, W)
    pix = pix.at[:, 0].set(pix[:, 0] * grid_h + (grid_h - 1) / 2)
    pix = pix.at[:, 1].set(pix[:, 1] * grid_w + (grid_w - 1) / 2)
    sample = partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest')
    # inner vmap runs per channel, outer vmap per batch element
    per_image = jax.vmap(sample, in_axes=(2, None), out_axes=2)
    return jax.vmap(per_image)(grid, pix)
|