diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b079d6ee58e9091f0f4fd471369e754a478339c2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,167 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +Makefile +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +weights/ + +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache/ +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +.DS_Store diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/app.py b/app.py index 04cc31aa8d0e06aeaac3b59bb361ed71d831e43f..d046d2971c1062ad4586700f605d996708a2d63e 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,209 @@ +import re +import os +import yaml +import tempfile +import subprocess +from pathlib import Path + +import torch import gradio as gr -def greet(name): - return "Hello " + name + "!!" +from src.flux.xflux_pipeline import XFluxPipeline + + +def list_dirs(path): + if path is None or path == "None" or path == "": + return + + if not os.path.exists(path): + path = os.path.dirname(path) + if not os.path.exists(path): + return + + if not os.path.isdir(path): + path = os.path.dirname(path) + + def natural_sort_key(s, regex=re.compile("([0-9]+)")): + return [ + int(text) if text.isdigit() else text.lower() for text in regex.split(s) + ] + + subdirs = [ + (item, os.path.join(path, item)) + for item in os.listdir(path) + if os.path.isdir(os.path.join(path, item)) + ] + subdirs = [ + filename + for item, filename in subdirs + if item[0] != "." and item not in ["__pycache__"] + ] + subdirs = sorted(subdirs, key=natural_sort_key) + if os.path.dirname(path) != "": + dirs = [os.path.dirname(path), path] + subdirs + else: + dirs = [path] + subdirs + + if os.sep == "\\": + dirs = [d.replace("\\", "/") for d in dirs] + for d in dirs: + yield d + +def list_train_data_dirs(): + current_train_data_dir = "." 
+ return list(list_dirs(current_train_data_dir)) + +def update_config(d, u): + for k, v in u.items(): + if isinstance(v, dict): + d[k] = update_config(d.get(k, {}), v) + else: + # convert Gradio components to strings + if hasattr(v, 'value'): + d[k] = str(v.value) + else: + try: + d[k] = int(v) + except (TypeError, ValueError): + d[k] = str(v) + return d + +def start_lora_training( + data_dir: str, output_dir: str, lr: float, steps: int, rank: int + ): + inputs = { + "data_config": { + "img_dir": data_dir, + }, + "output_dir": output_dir, + "learning_rate": lr, + "rank": rank, + "max_train_steps": steps, + } + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print(f"Creating folder {output_dir} for the output checkpoint file...") + + script_path = Path(__file__).resolve() + config_path = script_path.parent / "train_configs" / "test_lora.yaml" + with open(config_path, 'r') as file: + config = yaml.safe_load(file) + + config = update_config(config, inputs) + print("Config file is updated...", config) + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".yaml") as temp_file: + yaml.dump(config, temp_file, default_flow_style=False) + tmp_config_path = temp_file.name + + command = ["accelerate", "launch", "train_flux_lora_deepspeed.py", "--config", tmp_config_path] + result = subprocess.run(command, check=True) + + # rRemove the temporary file after the command is run + Path(tmp_config_path).unlink() + + return result + + +def create_demo( + model_type: str, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + offload: bool = False, + ckpt_dir: str = "", + ): + xflux_pipeline = XFluxPipeline(model_type, device, offload) + checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors")) + + with gr.Blocks() as demo: + gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}") + with gr.Tab("Inference"): + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt", value="handsome woman in the city") + + with gr.Accordion("Generation Options", open=False): + with gr.Row(): + width = gr.Slider(512, 2048, 1024, step=16, label="Width") + height = gr.Slider(512, 2048, 1024, step=16, label="Height") + neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo") + with gr.Row(): + num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps") + timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg") + with gr.Row(): + guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True) + true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True) + seed = gr.Textbox(-1, label="Seed (-1 for random)") + + with gr.Accordion("ControlNet Options", open=False): + control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type") + control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True) + local_path = gr.Dropdown(checkpoints, label="Controlnet Checkpoint", + info="Local Path to Controlnet weights (if no, it will be downloaded from HF)" + ) + controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True) + + with gr.Accordion("LoRA Options", open=False): + lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True) + lora_local_path = gr.Dropdown( + checkpoints, label="LoRA Checkpoint", info="Local Path to Lora weights" + ) + + with gr.Accordion("IP Adapter Options", open=False): + image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True) + ip_scale = 
gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale") + neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True) + neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale") + ip_local_path = gr.Dropdown( + checkpoints, label="IP Adapter Checkpoint", + info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)" + ) + generate_btn = gr.Button("Generate") + + with gr.Column(): + output_image = gr.Image(label="Generated Image") + download_btn = gr.File(label="Download full-resolution") + + inputs = [prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg, control_type, control_weight, + lora_weight, local_path, lora_local_path, ip_local_path + ] + generate_btn.click( + fn=xflux_pipeline.gradio_generate, + inputs=inputs, + outputs=[output_image, download_btn], + ) + + with gr.Tab("LoRA Finetuning"): + data_dir = gr.Dropdown(list_train_data_dirs(), + label="Training images (directory containing the training images)" + ) + output_dir = gr.Textbox(label="Output Path", value="lora_checkpoint") + + with gr.Accordion("Training Options", open=True): + lr = gr.Textbox(label="Learning Rate", value="1e-5") + steps = gr.Slider(10000, 20000, 20000, step=100, label="Train Steps") + rank = gr.Slider(1, 100, 16, step=1, label="LoRa Rank") + + training_btn = gr.Button("Start training") + training_btn.click( + fn=start_lora_training, + inputs=[data_dir, output_dir, lr, steps, rank], + outputs=[], + ) + + + return demo + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Flux") + parser.add_argument("--name", type=str, default="flux-dev", help="Model name") + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") + parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") + parser.add_argument("--share", action="store_true", help="Create a public link to your demo") + parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format") + args = parser.parse_args() -demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() + demo = create_demo(args.name, args.device, args.offload, args.ckpt_dir) + demo.launch(share=args.share) diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f44ee92f35ba6903718dffb304342d09e2661f6 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,24 @@ +# Configuration for Cog ⚙️ +# Reference: https://cog.run/yaml + +build: + gpu: true + cuda: "12.1" + python_version: "3.11" + python_packages: + - "accelerate==0.30.1" + - "deepspeed==0.14.4" + - "einops==0.8.0" + - "transformers==4.43.3" + - "huggingface-hub==0.24.5" + - "einops==0.8.0" + - "pandas==2.2.2" + - "opencv-python==4.10.0.84" + - "pillow==10.4.0" + - "optimum-quanto==0.2.4" + - "sentencepiece==0.2.0" + run: + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget + +# predict.py defines how predictions are run on your model +predict: "predict.py:Predictor" diff --git a/image_datasets/canny_dataset.py b/image_datasets/canny_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..728c75672ef9e34188c32cb1d819b55b50b76799 --- /dev/null +++ b/image_datasets/canny_dataset.py @@ -0,0 +1,59 @@ +import os +import 
pandas as pd +import numpy as np +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader +import json +import random +import cv2 + + +def canny_processor(image, low_threshold=100, high_threshold=200): + image = np.array(image) + image = cv2.Canny(image, low_threshold, high_threshold) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + canny_image = Image.fromarray(image) + return canny_image + + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +class CustomImageDataset(Dataset): + def __init__(self, img_dir, img_size=512): + self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] + self.images.sort() + self.img_size = img_size + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + try: + img = Image.open(self.images[idx]) + img = c_crop(img) + img = img.resize((self.img_size, self.img_size)) + hint = canny_processor(img) + img = torch.from_numpy((np.array(img) / 127.5) - 1) + img = img.permute(2, 0, 1) + hint = torch.from_numpy((np.array(hint) / 127.5) - 1) + hint = hint.permute(2, 0, 1) + json_path = self.images[idx].split('.')[0] + '.json' + prompt = json.load(open(json_path))['caption'] + return img, hint, prompt + except Exception as e: + print(e) + return self.__getitem__(random.randint(0, len(self.images) - 1)) + + +def loader(train_batch_size, num_workers, **args): + dataset = CustomImageDataset(**args) + return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True) diff --git a/image_datasets/dataset.py b/image_datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3166f62981660ed7985ed9bd164fa4f292e8c1ae --- /dev/null +++ b/image_datasets/dataset.py @@ -0,0 +1,92 @@ +import os +import pandas as pd +import numpy as np +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader +import json +import random + +def image_resize(img, max_size=512): + w, h = img.size + if w >= h: + new_w = max_size + new_h = int((max_size / w) * h) + else: + new_h = max_size + new_w = int((max_size / h) * w) + return img.resize((new_w, new_h)) + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +def crop_to_aspect_ratio(image, ratio="16:9"): + width, height = image.size + ratio_map = { + "16:9": (16, 9), + "4:3": (4, 3), + "1:1": (1, 1) + } + target_w, target_h = ratio_map[ratio] + target_ratio_value = target_w / target_h + + current_ratio = width / height + + if current_ratio > target_ratio_value: + new_width = int(height * target_ratio_value) + offset = (width - new_width) // 2 + crop_box = (offset, 0, offset + new_width, height) + else: + new_height = int(width / target_ratio_value) + offset = (height - new_height) // 2 + crop_box = (0, offset, width, offset + new_height) + + cropped_img = image.crop(crop_box) + return cropped_img + + +class CustomImageDataset(Dataset): + def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False): + self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] + 
self.images.sort() + self.img_size = img_size + self.caption_type = caption_type + self.random_ratio = random_ratio + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + try: + img = Image.open(self.images[idx]).convert('RGB') + if self.random_ratio: + ratio = random.choice(["16:9", "default", "1:1", "4:3"]) + if ratio != "default": + img = crop_to_aspect_ratio(img, ratio) + img = image_resize(img, self.img_size) + w, h = img.size + new_w = (w // 32) * 32 + new_h = (h // 32) * 32 + img = img.resize((new_w, new_h)) + img = torch.from_numpy((np.array(img) / 127.5) - 1) + img = img.permute(2, 0, 1) + json_path = self.images[idx].split('.')[0] + '.' + self.caption_type + if self.caption_type == "json": + prompt = json.load(open(json_path))['caption'] + else: + prompt = open(json_path).read() + return img, prompt + except Exception as e: + print(e) + return self.__getitem__(random.randint(0, len(self.images) - 1)) + + +def loader(train_batch_size, num_workers, **args): + dataset = CustomImageDataset(**args) + return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True) diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ccc7ac5160b5ab35e7c027676ee8435d6aecbf --- /dev/null +++ b/main.py @@ -0,0 +1,184 @@ +import argparse +from PIL import Image +import os + +from src.flux.xflux_pipeline import XFluxPipeline + + +def create_argparser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", type=str, required=True, + help="The input text prompt" + ) + parser.add_argument( + "--neg_prompt", type=str, default="", + help="The input text negative prompt" + ) + parser.add_argument( + "--img_prompt", type=str, default=None, + help="Path to input image prompt" + ) + parser.add_argument( + "--neg_img_prompt", type=str, default=None, + help="Path to input negative image prompt" + ) + parser.add_argument( + "--ip_scale", type=float, default=1.0, + help="Strength of input image prompt" + ) + parser.add_argument( + "--neg_ip_scale", type=float, default=1.0, + help="Strength of negative input image prompt" + ) + parser.add_argument( + "--local_path", type=str, default=None, + help="Local path to the model checkpoint (Controlnet)" + ) + parser.add_argument( + "--repo_id", type=str, default=None, + help="A HuggingFace repo id to download model (Controlnet)" + ) + parser.add_argument( + "--name", type=str, default=None, + help="A filename to download from HuggingFace" + ) + parser.add_argument( + "--ip_repo_id", type=str, default=None, + help="A HuggingFace repo id to download model (IP-Adapter)" + ) + parser.add_argument( + "--ip_name", type=str, default=None, + help="A IP-Adapter filename to download from HuggingFace" + ) + parser.add_argument( + "--ip_local_path", type=str, default=None, + help="Local path to the model checkpoint (IP-Adapter)" + ) + parser.add_argument( + "--lora_repo_id", type=str, default=None, + help="A HuggingFace repo id to download model (LoRA)" + ) + parser.add_argument( + "--lora_name", type=str, default=None, + help="A LoRA filename to download from HuggingFace" + ) + parser.add_argument( + "--lora_local_path", type=str, default=None, + help="Local path to the model checkpoint (Controlnet)" + ) + parser.add_argument( + "--device", type=str, default="cuda", + help="Device to use (e.g. 
cpu, cuda:0, cuda:1, etc.)" + ) + parser.add_argument( + "--offload", action='store_true', help="Offload model to CPU when not in use" + ) + parser.add_argument( + "--use_ip", action='store_true', help="Load IP model" + ) + parser.add_argument( + "--use_lora", action='store_true', help="Load Lora model" + ) + parser.add_argument( + "--use_controlnet", action='store_true', help="Load Controlnet model" + ) + parser.add_argument( + "--num_images_per_prompt", type=int, default=1, + help="The number of images to generate per prompt" + ) + parser.add_argument( + "--image", type=str, default=None, help="Path to image" + ) + parser.add_argument( + "--lora_weight", type=float, default=0.9, help="Lora model strength (from 0 to 1.0)" + ) + parser.add_argument( + "--control_weight", type=float, default=0.8, help="Controlnet model strength (from 0 to 1.0)" + ) + parser.add_argument( + "--control_type", type=str, default="canny", + choices=("canny", "openpose", "depth", "zoe", "hed", "hough", "tile"), + help="Name of controlnet condition, example: canny" + ) + parser.add_argument( + "--model_type", type=str, default="flux-dev", + choices=("flux-dev", "flux-dev-fp8", "flux-schnell"), + help="Model type to use (flux-dev, flux-dev-fp8, flux-schnell)" + ) + parser.add_argument( + "--width", type=int, default=1024, help="The width for generated image" + ) + parser.add_argument( + "--height", type=int, default=1024, help="The height for generated image" + ) + parser.add_argument( + "--num_steps", type=int, default=25, help="The num_steps for diffusion process" + ) + parser.add_argument( + "--guidance", type=float, default=4, help="The guidance for diffusion process" + ) + parser.add_argument( + "--seed", type=int, default=123456789, help="A seed for reproducible inference" + ) + parser.add_argument( + "--true_gs", type=float, default=3.5, help="true guidance" + ) + parser.add_argument( + "--timestep_to_start_cfg", type=int, default=5, help="timestep to start true guidance" + ) + parser.add_argument( + "--save_path", type=str, default='results', help="Path to save" + ) + return parser + + +def main(args): + if args.image: + image = Image.open(args.image) + else: + image = None + + xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload) + if args.use_ip: + print('load ip-adapter:', args.ip_local_path, args.ip_repo_id, args.ip_name) + xflux_pipeline.set_ip(args.ip_local_path, args.ip_repo_id, args.ip_name) + if args.use_lora: + print('load lora:', args.lora_local_path, args.lora_repo_id, args.lora_name) + xflux_pipeline.set_lora(args.lora_local_path, args.lora_repo_id, args.lora_name, args.lora_weight) + if args.use_controlnet: + print('load controlnet:', args.local_path, args.repo_id, args.name) + xflux_pipeline.set_controlnet(args.control_type, args.local_path, args.repo_id, args.name) + + image_prompt = Image.open(args.img_prompt) if args.img_prompt else None + neg_image_prompt = Image.open(args.neg_img_prompt) if args.neg_img_prompt else None + + for _ in range(args.num_images_per_prompt): + result = xflux_pipeline( + prompt=args.prompt, + controlnet_image=image, + width=args.width, + height=args.height, + guidance=args.guidance, + num_steps=args.num_steps, + seed=args.seed, + true_gs=args.true_gs, + control_weight=args.control_weight, + neg_prompt=args.neg_prompt, + timestep_to_start_cfg=args.timestep_to_start_cfg, + image_prompt=image_prompt, + neg_image_prompt=neg_image_prompt, + ip_scale=args.ip_scale, + neg_ip_scale=args.neg_ip_scale, + ) + if not os.path.exists(args.save_path): + 
os.mkdir(args.save_path) + ind = len(os.listdir(args.save_path)) + result.save(os.path.join(args.save_path, f"result_{ind}.png")) + args.seed = args.seed + 1 + + +if __name__ == "__main__": + args = create_argparser().parse_args() + main(args) diff --git a/models_licence/LICENSE-FLUX1-dev b/models_licence/LICENSE-FLUX1-dev new file mode 100644 index 0000000000000000000000000000000000000000..d91cf0bcef46f7ab49551034ccf3bea6b765f8d6 --- /dev/null +++ b/models_licence/LICENSE-FLUX1-dev @@ -0,0 +1,42 @@ +FLUX.1 [dev] Non-Commercial License +Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] text-to-image AI model and its elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI model made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). +By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity. + 1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings: + a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License. + b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be. + c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output: (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment, (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. 
For clarity, use for revenue-generating activity or direct interactions with or impacts on end users, or use to train, fine tune or distill other models for commercial use is not a Non-Commercial purpose. + d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters. + e. “you” or “your” means the individual or entity entering into this License with Company. + 2. License Grant. + a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein in regarding the FLUX.1 [dev] Model also applies to any Derivative you create or that are created on your behalf. + b. Non-Commercial Use Only. You may only access, use, Distribute, or creative Derivatives of or the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If You want to use a FLUX.1 [dev] Model a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please contact Company at the following e-mail address if you want to discuss such a license: info@blackforestlabs.ai. + c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License. + d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model. + 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions: + a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License; + b. you must make prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”): +“The FLUX.1 [dev] Model is licensed by Black Forest Labs. 
Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs. Inc. +IN NO EVENT SHALL BLACK FOREST LABS, INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.” + c. in the case of Distribution of Derivatives made by you, you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; and + d. in the case of Distribution of Derivatives made by you, any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions. + e. In the case of Distribution of Derivatives made by you, you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing. + 4. Restrictions. You will not, and will not permit, assist or cause any third party to + a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing; + b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model; + c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; or + d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License. + e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model; + f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. 
restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods. + 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL IS PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS. + 6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE. + 7. INDEMNIFICATION + +You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (as well as any Output, results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. 
This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties. + 8. Termination; Survival. + a. This License will automatically terminate upon any breach by you of the terms of this License. + b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you. + c. If You initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model or any Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated. + d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model and any Derivatives. The following sections survive termination of this License 2(c), 2(d), 4-11. + 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk. + 10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name or mark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators. + 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Company. 
\ No newline at end of file diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5cd8e5737556f82cef8fb88692456589934bd4 --- /dev/null +++ b/predict.py @@ -0,0 +1,134 @@ +# Prediction interface for Cog ⚙️ +# https://cog.run/python + +from cog import BasePredictor, Input, Path +import os +import time +import torch +import subprocess +from PIL import Image +from typing import List +from image_datasets.canny_dataset import canny_processor, c_crop +from src.flux.util import load_ae, load_clip, load_t5, load_flow_model, load_controlnet, load_safetensors + +OUTPUT_DIR = "controlnet_results" +MODEL_CACHE = "checkpoints" +CONTROLNET_URL = "https://huggingface.co/XLabs-AI/flux-controlnet-canny/resolve/main/controlnet.safetensors" +T5_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/t5-cache.tar" +CLIP_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/clip-cache.tar" +HF_TOKEN = "hf_..." # Your HuggingFace token + +def download_weights(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-xf", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) + +def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool): + t5 = load_t5(device, max_length=256 if is_schnell else 512) + clip = load_clip(device) + model = load_flow_model(name, device="cpu" if offload else device) + ae = load_ae(name, device="cpu" if offload else device) + controlnet = load_controlnet(name, device).to(torch.bfloat16) + return model, ae, t5, clip, controlnet + +class Predictor(BasePredictor): + def setup(self) -> None: + """Load the model into memory to make running multiple predictions efficient""" + t1 = time.time() + os.system(f"huggingface-cli login --token {HF_TOKEN}") + name = "flux-dev" + self.offload = False + checkpoint = "controlnet.safetensors" + + print("Checking ControlNet weights") + checkpoint = "controlnet.safetensors" + if not os.path.exists(checkpoint): + os.system(f"wget {CONTROLNET_URL}") + print("Checking T5 weights") + if not os.path.exists(MODEL_CACHE+"/models--google--t5-v1_1-xxl"): + download_weights(T5_URL, MODEL_CACHE) + print("Checking CLIP weights") + if not os.path.exists(MODEL_CACHE+"/models--openai--clip-vit-large-patch14"): + download_weights(CLIP_URL, MODEL_CACHE) + + self.is_schnell = False + device = "cuda" + self.torch_device = torch.device(device) + model, ae, t5, clip, controlnet = get_models( + name, + device=self.torch_device, + offload=self.offload, + is_schnell=self.is_schnell, + ) + self.ae = ae + self.t5 = t5 + self.clip = clip + self.controlnet = controlnet + self.model = model.to(self.torch_device) + if '.safetensors' in checkpoint: + checkpoint1 = load_safetensors(checkpoint) + else: + checkpoint1 = torch.load(checkpoint, map_location='cpu') + + controlnet.load_state_dict(checkpoint1, strict=False) + t2 = time.time() + print(f"Setup time: {t2 - t1}") + + def preprocess_canny_image(self, image_path: str, width: int = 512, height: int = 512): + image = Image.open(image_path) + image = c_crop(image) + image = image.resize((width, height)) + image = canny_processor(image) + return image + + def predict( + self, + prompt: str = Input(description="Input prompt", default="a handsome viking man with white hair, cinematic, MM full HD"), + image: Path = Input(description="Input image", default=None), + num_inference_steps: int = Input(description="Number of 
inference steps", ge=1, le=64, default=28), + cfg: float = Input(description="CFG", ge=0, le=10, default=3.5), + seed: int = Input(description="Random seed", default=None) + ) -> List[Path]: + """Run a single prediction on the model""" + if seed is None: + seed = int.from_bytes(os.urandom(2), "big") + print(f"Using seed: {seed}") + + # clean output dir + output_dir = "controlnet_results" + os.system(f"rm -rf {output_dir}") + + input_image = str(image) + img = Image.open(input_image) + width, height = img.size + # Resize input image if it's too large + max_image_size = 1536 + scale = min(max_image_size / width, max_image_size / height, 1) + if scale < 1: + width = int(width * scale) + height = int(height * scale) + print(f"Scaling image down to {width}x{height}") + img = img.resize((width, height), resample=Image.Resampling.LANCZOS) + input_image = "/tmp/resized_image.png" + img.save(input_image) + + subprocess.check_call( + ["python3", "main.py", + "--local_path", "controlnet.safetensors", + "--image", input_image, + "--use_controlnet", + "--control_type", "canny", + "--prompt", prompt, + "--width", str(width), + "--height", str(height), + "--num_steps", str(num_inference_steps), + "--guidance", str(cfg), + "--seed", str(seed) + ], close_fds=False) + + # Find the first file that begins with "controlnet_result_" + for file in os.listdir(output_dir): + if file.startswith("controlnet_result_"): + return [Path(os.path.join(output_dir, file))] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2f7c9f2299356262a6c4da02d53a24cdfa045162 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +accelerate==0.30.1 +deepspeed==0.14.4 +einops==0.8.0 +transformers==4.43.3 +huggingface-hub==0.24.5 +optimum-quanto +datasets +omegaconf +diffusers +sentencepiece +opencv-python +matplotlib +onnxruntime +torchvision +timm diff --git a/src/flux/__init__.py b/src/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43c365a49d6980e88acba10ef3069f110a59644a --- /dev/null +++ b/src/flux/__init__.py @@ -0,0 +1,11 @@ +try: + from ._version import version as __version__ # type: ignore + from ._version import version_tuple +except ImportError: + __version__ = "unknown (no version information available)" + version_tuple = (0, 0, "unknown", "noinfo") + +from pathlib import Path + +PACKAGE = __package__.replace("_", "-") +PACKAGE_ROOT = Path(__file__).parent diff --git a/src/flux/__main__.py b/src/flux/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5cf0fd2444d4cda4053fa74dad3371556b886e5 --- /dev/null +++ b/src/flux/__main__.py @@ -0,0 +1,4 @@ +from .cli import app + +if __name__ == "__main__": + app() diff --git a/src/flux/annotator/canny/__init__.py b/src/flux/annotator/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b --- /dev/null +++ b/src/flux/annotator/canny/__init__.py @@ -0,0 +1,6 @@ +import cv2 + + +class CannyDetector: + def __call__(self, img, low_threshold, high_threshold): + return cv2.Canny(img, low_threshold, high_threshold) diff --git a/src/flux/annotator/ckpts/ckpts.txt b/src/flux/annotator/ckpts/ckpts.txt new file mode 100644 index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b --- /dev/null +++ b/src/flux/annotator/ckpts/ckpts.txt @@ -0,0 +1 @@ +Weights here. 
\ No newline at end of file diff --git a/src/flux/annotator/dwpose/__init__.py b/src/flux/annotator/dwpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6e172d05c9de3f1cdd61e330ad8d6dde46dfdd --- /dev/null +++ b/src/flux/annotator/dwpose/__init__.py @@ -0,0 +1,68 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import torch +import numpy as np +from . import util +from .wholebody import Wholebody + +def draw_pose(pose, H, W): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + canvas = util.draw_bodypose(canvas, candidate, subset) + + canvas = util.draw_handpose(canvas, hands) + + canvas = util.draw_facepose(canvas, faces) + + return canvas + + +class DWposeDetector: + def __init__(self, device): + + self.pose_estimation = Wholebody(device) + + def __call__(self, oriImg): + oriImg = oriImg.copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.pose_estimation(oriImg) + nums, keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18] + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > 0.3: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset<0.3 + candidate[un_visible] = -1 + + foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + bodies = dict(candidate=body, subset=score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + return draw_pose(pose, H, W) diff --git a/src/flux/annotator/dwpose/onnxdet.py b/src/flux/annotator/dwpose/onnxdet.py new file mode 100644 index 0000000000000000000000000000000000000000..e0411c96a5eef41e981bde5481ef7786b242f1fa --- /dev/null +++ b/src/flux/annotator/dwpose/onnxdet.py @@ -0,0 +1,125 @@ +import cv2 +import numpy as np + +import onnxruntime + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. 
Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} + output = session.run(None, ort_inputs) + predictions = demo_postprocess(output[0], input_shape)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 
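+    # The raw predictions are (cx, cy, w, h) boxes; the assignments above
+    # convert them to corner format (x1, y1, x2, y2). Dividing by `ratio`
+    # below undoes the letterbox resize applied in `preprocess`, mapping the
+    # boxes back to original-image coordinates.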
+ boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is not None: + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = final_cls_inds == 0 + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + else: + final_boxes = np.array([]) + + return final_boxes diff --git a/src/flux/annotator/dwpose/onnxpose.py b/src/flux/annotator/dwpose/onnxpose.py new file mode 100644 index 0000000000000000000000000000000000000000..79cd4a06241123af81ea22446a4ca8816716443f --- /dev/null +++ b/src/flux/annotator/dwpose/onnxpose.py @@ -0,0 +1,360 @@ +from typing import List, Tuple + +import cv2 +import numpy as np +import onnxruntime as ort + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for RTMPose model inference. + + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + + +def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: + """Inference RTMPose model. + + Args: + sess (ort.InferenceSession): ONNXRuntime session. + img (np.ndarray): Input image in shape. + + Returns: + outputs (np.ndarray): Output of RTMPose model. + """ + all_out = [] + # build input + for i in range(len(img)): + input = [img[i].transpose(2, 0, 1)] + + # build output + sess_input = {sess.get_inputs()[0].name: input} + sess_output = [] + for out in sess.get_outputs(): + sess_output.append(out.name) + + # run model + outputs = sess.run(sess_output, sess_input) + all_out.append(outputs) + + return all_out + + +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for RTMPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. 
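+
+    Note:
+        Keypoints are decoded with SimCC per detected bbox and mapped from
+        model-input coordinates back to the original image frame via
+        `keypoints / model_input_size * scale + center - scale / 2`.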
+ """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. 
+ (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg): + h, w = session.get_inputs()[0].shape[2:] + model_input_size = (w, h) + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/src/flux/annotator/dwpose/util.py b/src/flux/annotator/dwpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..73d7d0153b38d143eb8090e07a9784a274b619ed --- /dev/null +++ b/src/flux/annotator/dwpose/util.py @@ -0,0 +1,297 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 
170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - 
posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/src/flux/annotator/dwpose/wholebody.py b/src/flux/annotator/dwpose/wholebody.py new file mode 100644 index 0000000000000000000000000000000000000000..d73f19d61c238c47cf7de98d01385b2150a5361f --- /dev/null +++ b/src/flux/annotator/dwpose/wholebody.py @@ -0,0 +1,48 @@ +import cv2 +import numpy as np + +import onnxruntime as ort +from huggingface_hub import hf_hub_download +from .onnxdet import inference_detector +from .onnxpose import 
inference_pose + + +class Wholebody: + def __init__(self, device="cuda:0"): + providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] + onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx") + onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx") + + self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) + self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) + + def __call__(self, oriImg): + det_result = inference_detector(self.session_det, oriImg) + keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores + + diff --git a/src/flux/annotator/hed/__init__.py b/src/flux/annotator/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70d11c9e62133149d38091a597a1b6691ff8f1b6 --- /dev/null +++ b/src/flux/annotator/hed/__init__.py @@ -0,0 +1,95 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. 
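+# Usage sketch (assumes a CUDA device and that the "lllyasviel/Annotators" HED checkpoint is reachable), e.g.:
+#   edge = HEDdetector()(cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB))  # HxWx3 uint8 in, HxW uint8 edge map out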
+# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import cv2 +import torch +import numpy as np + +from huggingface_hub import hf_hub_download +from einops import rearrange +from ...annotator.util import annotator_ckpts_path + + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + + +class HEDdetector: + def __init__(self): + modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") + if not os.path.exists(modelpath): + modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth") + self.netNetwork = ControlNetHED_Apache2().float().cuda().eval() + self.netNetwork.load_state_dict(torch.load(modelpath)) + + def __call__(self, input_image): + assert input_image.ndim == 3 + H, W, C = input_image.shape + with torch.no_grad(): + image_hed = torch.from_numpy(input_image.copy()).float().cuda() + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + return edge + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + 
np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z diff --git a/src/flux/annotator/midas/LICENSE b/src/flux/annotator/midas/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/src/flux/annotator/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/flux/annotator/midas/__init__.py b/src/flux/annotator/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36789767f35bcc169c2cbf096e2747539df4f14d --- /dev/null +++ b/src/flux/annotator/midas/__init__.py @@ -0,0 +1,42 @@ +# Midas Depth Estimation +# From https://github.com/isl-org/MiDaS +# MIT LICENSE + +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self): + self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + + return depth_image, normal_image diff --git a/src/flux/annotator/midas/api.py b/src/flux/annotator/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..6226a39d80de978162a7238cec1c4d4a64bacbe9 --- /dev/null +++ b/src/flux/annotator/midas/api.py @@ -0,0 +1,168 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from huggingface_hub import hf_hub_download + +from .midas.dpt_depth import DPTDepthModel 
+from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from ...annotator.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt") + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + 
) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/src/flux/annotator/midas/midas/__init__.py b/src/flux/annotator/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/midas/midas/base_model.py b/src/flux/annotator/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/src/flux/annotator/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/src/flux/annotator/midas/midas/blocks.py b/src/flux/annotator/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/src/flux/annotator/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = 
out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. 
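+
+        Args:
+            *xs: one or two feature maps; when a second input is given it is
+                refined by `resConfUnit1` and added to the first before fusion
+                and 2x upsampling.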
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. 
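+
+        Args:
+            *xs: one or two feature maps; an optional second input is passed
+                through `resConfUnit1` and fused via the quantization-friendly
+                `skip_add` before upsampling and the final 1x1 `out_conv`.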
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/src/flux/annotator/midas/midas/dpt_depth.py b/src/flux/annotator/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/src/flux/annotator/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/src/flux/annotator/midas/midas/midas_net.py b/src/flux/annotator/midas/midas/midas_net.py new file mode 100644 index 
0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/src/flux/annotator/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/src/flux/annotator/midas/midas/midas_net_custom.py b/src/flux/annotator/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/src/flux/annotator/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. 
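+    Lightweight variant: defaults to an efficientnet_lite3 backbone with
+    optional per-stage channel expansion (see the `blocks={'expand': True}` default).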
+ """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/src/flux/annotator/midas/midas/transforms.py b/src/flux/annotator/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/src/flux/annotator/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
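The `fuse_model` helper above walks `named_modules()` and fuses Conv2d -> BatchNorm2d (-> ReLU) runs for quantization. A minimal, self-contained sketch of the same fusion on a toy module (module names and shapes are illustrative, not part of the diff):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    """Toy module mirroring the Conv2d -> BatchNorm2d -> ReLU pattern fuse_model() looks for."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = Toy().eval()  # Conv+BN fusion requires eval mode
torch.quantization.fuse_modules(m, [["conv", "bn", "relu"]], inplace=True)
print(m.conv)     # fused Conv+ReLU module; m.bn and m.relu become nn.Identity()
```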
+ + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if 
self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/src/flux/annotator/midas/midas/vit.py b/src/flux/annotator/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/src/flux/annotator/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h 
// pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + 
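The backbone builders above harvest intermediate transformer activations with `register_forward_hook` and the `get_activation` closure. A standalone sketch of that hook pattern on a toy network (toy layers, same mechanism):

```python
import torch
import torch.nn as nn

# Stash each hooked module's output in a dict keyed by name,
# then read the dict after a forward pass, as forward_vit does above.
activations = {}

def get_activation(name):
    def hook(module, inputs, output):
        activations[name] = output
    return hook

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
net[0].register_forward_hook(get_activation("1"))
net[2].register_forward_hook(get_activation("2"))

_ = net(torch.randn(2, 8))
print(activations["1"].shape, activations["2"].shape)  # torch.Size([2, 8]) torch.Size([2, 4])
```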
pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + 
use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/src/flux/annotator/midas/utils.py b/src/flux/annotator/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/src/flux/annotator/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. 
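The PFM helpers above store depth as float32 with endianness encoded in the sign of the scale header. A small greyscale round-trip sketch (the import path is hypothetical; a color image would additionally exercise the "PF" branch):

```python
import numpy as np
# Hypothetical import path for the helpers defined above.
from flux.annotator.midas.utils import read_pfm, write_pfm

depth = np.random.rand(4, 6).astype(np.float32)   # greyscale H x W map
write_pfm("depth_example.pfm", depth, scale=1)

restored, scale = read_pfm("depth_example.pfm")
print(scale)                          # 1.0 (sign is folded into the endianness flag)
print(np.allclose(depth, restored))   # True: flipud on write is undone on read
```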
+ + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/src/flux/annotator/mlsd/LICENSE b/src/flux/annotator/mlsd/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363 --- /dev/null +++ b/src/flux/annotator/mlsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021-present NAVER Corp. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
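The `MLSDdetector` wrapper added in `src/flux/annotator/mlsd/__init__.py` just below fetches `mlsd_large_512_fp32.pth` from the Hub when it is not cached locally and returns a single-channel line map drawn at 255. A hedged usage sketch (the import path and threshold values are illustrative, and a CUDA device is required by the wrapper):

```python
import cv2
# Hypothetical import path; adjust to the installed package layout.
from flux.annotator.mlsd import MLSDdetector

detector = MLSDdetector()                       # loads the large model onto CUDA

img = cv2.imread("room.png")                    # HWC uint8 image as returned by OpenCV
line_map = detector(img, thr_v=0.1, thr_d=0.1)  # score / distance thresholds (illustrative)

print(line_map.shape, line_map.dtype)           # (H, W), uint8; 255 where line segments were drawn
cv2.imwrite("room_mlsd.png", line_map)
```

Note that `pred_lines` in `mlsd/utils.py` below moves the input batch to a hard-coded `"cuda:4"` device, so machines without that device index will need the string adjusted.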
\ No newline at end of file diff --git a/src/flux/annotator/mlsd/__init__.py b/src/flux/annotator/mlsd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5028aef051a1e67caae1f0d23a0b0dbca883a7f8 --- /dev/null +++ b/src/flux/annotator/mlsd/__init__.py @@ -0,0 +1,40 @@ +# MLSD Line Detection +# From https://github.com/navervision/mlsd +# Apache-2.0 license + +import cv2 +import numpy as np +import torch +import os + +from einops import rearrange +from huggingface_hub import hf_hub_download +from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + +from ...annotator.util import annotator_ckpts_path + + +class MLSDdetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth") + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + self.model = model.cuda().eval() + + def __call__(self, input_image, thr_v, thr_d): + assert input_image.ndim == 3 + img = input_image + img_output = np.zeros_like(img) + try: + with torch.no_grad(): + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception as e: + pass + return img_output[:, :, 0] diff --git a/src/flux/annotator/mlsd/models/mbv2_mlsd_large.py b/src/flux/annotator/mlsd/models/mbv2_mlsd_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603 --- /dev/null +++ b/src/flux/annotator/mlsd/models/mbv2_mlsd_large.py @@ -0,0 +1,292 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + if self.upscale: + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = 
self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = 
[ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + self.features = nn.Sequential(*features) + self.fpn_selected = [1, 3, 6, 10, 13] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + if pretrained: + self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c1, c2, c3, c4, c5 = fpn_features + return c1, c2, c3, c4, c5 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Large(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Large, self).__init__() + + self.backbone = MobileNetV2(pretrained=False) + ## A, B + self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, + out_c1= 64, out_c2=64, + upscale=False) + self.block16 = BlockTypeB(128, 64) + + ## A, B + self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, + out_c1= 64, out_c2= 64) + self.block18 = BlockTypeB(128, 64) + + ## A, B + self.block19 = BlockTypeA(in_c1=24, in_c2=64, + out_c1=64, out_c2=64) + self.block20 = BlockTypeB(128, 64) + + ## A, B, C + self.block21 = BlockTypeA(in_c1=16, in_c2=64, + out_c1=64, out_c2=64) + self.block22 = BlockTypeB(128, 64) + + self.block23 = BlockTypeC(64, 16) + + def forward(self, x): + c1, c2, c3, c4, c5 = self.backbone(x) + + x = self.block15(c4, c5) + x = self.block16(x) + + x = self.block17(c3, x) + x = self.block18(x) + + x = self.block19(c2, x) + x = self.block20(x) + + x = self.block21(c1, x) + x = self.block22(x) + x = self.block23(x) + x = x[:, 7:, :, :] + + return x \ No newline at end of file diff --git a/src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py b/src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83 --- /dev/null +++ b/src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py @@ -0,0 +1,275 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + 
nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult 
(float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + #[6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + self.features = nn.Sequential(*features) + + self.fpn_selected = [3, 6, 10] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + #if pretrained: + # self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c2, c3, c4 = fpn_features + return c2, c3, c4 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Tiny(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Tiny, self).__init__() + + self.backbone = MobileNetV2(pretrained=True) + + self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, + out_c1= 64, out_c2=64) + self.block13 = BlockTypeB(128, 64) + + self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, + out_c1= 32, out_c2= 32) + self.block15 = BlockTypeB(64, 64) + + self.block16 = BlockTypeC(64, 16) + + def forward(self, x): + c2, c3, c4 = self.backbone(x) + + x = self.block12(c3, c4) + x = self.block13(x) + x = self.block14(c2, x) + x = self.block15(x) + x = self.block16(x) + x = x[:, 7:, :, :] + #print(x.shape) + x = F.interpolate(x, scale_factor=2.0, mode='bilinear', 
align_corners=True) + + return x \ No newline at end of file diff --git a/src/flux/annotator/mlsd/utils.py b/src/flux/annotator/mlsd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..848a9fd7f9ff9d909f18c5d3ff55786c5a4b547a --- /dev/null +++ b/src/flux/annotator/mlsd/utils.py @@ -0,0 +1,580 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. +Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=[512, 512], + score_thr=0.10, + dist_thr=20.0): + h, w, _ = image.shape + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().to("cuda:4") + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=[512, 512], + params={'score': 0.06, + 'outside_ratio': 0.28, + 'inside_ratio': 0.45, + 'w_overlap': 0.0, + 'w_degree': 1.95, + 'w_length': 0.0, + 'w_area': 1.86, + 'w_center': 0.14}): + ''' + shape = [height, width] + ''' + h, w, _ = image.shape + original_shape = [h, w] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image 
= torch.from_numpy(batch_image).float().cuda() + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] # (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt 
> 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... 
+ ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! 
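+                # check_distance (below) accepts an intersection only if it lies close to BOTH
+                # segments: dist_inter_to_segmentX[i, j] holds the sorted distances from the
+                # intersection to the two endpoints of segment X. If the larger distance exceeds
+                # the segment length, the intersection falls beyond an endpoint and must lie
+                # within outside_ratio * length of that endpoint; otherwise it falls between the
+                # endpoints and must lie within inside_ratio * length of the nearer endpoint.
+                # E.g. for a 100 px segment with outside_ratio=0.28, an intersection up to
+                # ~28 px past an endpoint is still kept as a corner candidate.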
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, color_info = 3, 'cyan' + else: + corner_info, color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... 
+ ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + top_square = None + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + start_point = square[start_idx] + end_point = square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + 
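+            # Per-square heuristics collected in this loop:
+            #   overlap_scores — covered length / perimeter, i.e. how much of the square's
+            #                    outline is actually supported by detected segments
+            #   degree_scores  — ratio of opposite corner angles, averaged over both pairs
+            #                    (1.0 when opposite angles match)
+            #   length_scores  — ratio of opposite side lengths, averaged over both pairs
+            #                    (1.0 when opposite sides match)
+            # They are combined below with the w_* weights from `params`, together with an
+            # area term and a distance-to-image-center penalty.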
###################################### + + overlap_scores = np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_w = [0.0, 1.0, 10.0, 0.5, 1.0] + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + best_square = [] + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception as e: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/src/flux/annotator/tile/__init__.py b/src/flux/annotator/tile/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c96899289d6e04796140cd7eff9c08e5f693af02 --- /dev/null +++ b/src/flux/annotator/tile/__init__.py @@ -0,0 +1,26 @@ +import random +import cv2 +from .guided_filter import FastGuidedFilter + + +class TileDetector: + # https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0 + def __init__(self): + pass + + def __call__(self, image): + blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0] + radius = random.sample([i for i in range(1, 40, 2)], k=1)[0] + eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0] + scale_factor = random.sample([i / 10. 
for i in range(10, 181, 5)], k=1)[0] + + ksize = int(blur_strength) + if ksize % 2 == 0: + ksize += 1 + + if random.random() > 0.5: + image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2) + if random.random() > 0.5: + filter = FastGuidedFilter(image, radius, eps, scale_factor) + image = filter.filter(image) + return image diff --git a/src/flux/annotator/tile/guided_filter.py b/src/flux/annotator/tile/guided_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7172a5e144672eea26551ef75f70b90a2f96d6 --- /dev/null +++ b/src/flux/annotator/tile/guided_filter.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- +## @package guided_filter.core.filters +# +# Implementation of guided filter. +# * GuidedFilter: Original guided filter. +# * FastGuidedFilter: Fast version of the guided filter. +# @author tody +# @date 2015/08/26 + +import numpy as np +import cv2 + +## Convert image into float32 type. +def to32F(img): + if img.dtype == np.float32: + return img + return (1.0 / 255.0) * np.float32(img) + +## Convert image into uint8 type. +def to8U(img): + if img.dtype == np.uint8: + return img + return np.clip(np.uint8(255.0 * img), 0, 255) + +## Return if the input image is gray or not. +def _isGray(I): + return len(I.shape) == 2 + + +## Return down sampled image. +# @param scale (w/s, h/s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _downSample(I, scale=4, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST) + + +## Return up sampled image. +# @param scale (w*s, h*s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _upSample(I, scale=2, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) + +## Fast guide filter. +class FastGuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + # @param scale Down sampled scale. + def __init__(self, I, radius=5, epsilon=0.4, scale=4): + I_32F = to32F(I) + self._I = I_32F + h, w = I.shape[:2] + + I_sub = _downSample(I_32F, scale) + + self._I_sub = I_sub + radius = int(radius / scale) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + shape_original = p.shape[:2] + + p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2]) + + if _isGray(p_sub): + return self._filterGray(p_sub, shape_original) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original) + return to8U(q) + + def _filterGray(self, p_sub, shape_original): + ab_sub = self._guided_filter._computeCoefficients(p_sub) + ab = [_upSample(abi, shape=shape_original) for abi in ab_sub] + return self._guided_filter._computeOutput(ab, self._I) + + +## Guide filter. +class GuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. 
+ # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + I_32F = to32F(I) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return to8U(self._guided_filter.filter(p)) + + +## Common parts of guided filter. +# +# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor. +# Based on guided_filter._computeCoefficients, guided_filter._computeOutput, +# GuidedFilterCommon.filter computes filtered image for color and gray. +class GuidedFilterCommon: + def __init__(self, guided_filter): + self._guided_filter = guided_filter + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + if _isGray(p_32F): + return self._filterGray(p_32F) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_32F[:, :, ci]) + return q + + def _filterGray(self, p): + ab = self._guided_filter._computeCoefficients(p) + return self._guided_filter._computeOutput(ab, self._guided_filter._I) + + +## Guided filter for gray guidance image. +class GuidedFilterGray: + # @param I Input gray guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + self._I_mean = cv2.blur(I, (r, r)) + I_mean_sq = cv2.blur(I ** 2, (r, r)) + self._I_var = I_mean_sq - self._I_mean ** 2 + + def _computeCoefficients(self, p): + r = self._radius + p_mean = cv2.blur(p, (r, r)) + p_cov = p_mean - self._I_mean * p_mean + a = p_cov / (self._I_var + self._epsilon) + b = p_mean - a * self._I_mean + a_mean = cv2.blur(a, (r, r)) + b_mean = cv2.blur(b, (r, r)) + return a_mean, b_mean + + def _computeOutput(self, ab, I): + a_mean, b_mean = ab + return a_mean * I + b_mean + + +## Guided filter for color guidance image. +class GuidedFilterColor: + # @param I Input color guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.2): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. 
+ def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + eps = self._epsilon + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + self._Ir_mean = cv2.blur(Ir, (r, r)) + self._Ig_mean = cv2.blur(Ig, (r, r)) + self._Ib_mean = cv2.blur(Ib, (r, r)) + + Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps + Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean + Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean + Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps + Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean + Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps + + Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var + Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var + Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var + Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var + Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var + Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var + + I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var + Irr_inv /= I_cov + Irg_inv /= I_cov + Irb_inv /= I_cov + Igg_inv /= I_cov + Igb_inv /= I_cov + Ibb_inv /= I_cov + + self._Irr_inv = Irr_inv + self._Irg_inv = Irg_inv + self._Irb_inv = Irb_inv + self._Igg_inv = Igg_inv + self._Igb_inv = Igb_inv + self._Ibb_inv = Ibb_inv + + def _computeCoefficients(self, p): + r = self._radius + I = self._I + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + p_mean = cv2.blur(p, (r, r)) + + Ipr_mean = cv2.blur(Ir * p, (r, r)) + Ipg_mean = cv2.blur(Ig * p, (r, r)) + Ipb_mean = cv2.blur(Ib * p, (r, r)) + + Ipr_cov = Ipr_mean - self._Ir_mean * p_mean + Ipg_cov = Ipg_mean - self._Ig_mean * p_mean + Ipb_cov = Ipb_mean - self._Ib_mean * p_mean + + ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov + ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov + ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov + b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean + + ar_mean = cv2.blur(ar, (r, r)) + ag_mean = cv2.blur(ag, (r, r)) + ab_mean = cv2.blur(ab, (r, r)) + b_mean = cv2.blur(b, (r, r)) + + return ar_mean, ag_mean, ab_mean, b_mean + + def _computeOutput(self, ab, I): + ar_mean, ag_mean, ab_mean, b_mean = ab + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + q = (ar_mean * Ir + + ag_mean * Ig + + ab_mean * Ib + + b_mean) + + return q diff --git a/src/flux/annotator/util.py b/src/flux/annotator/util.py new file mode 100644 index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05 --- /dev/null +++ b/src/flux/annotator/util.py @@ -0,0 +1,38 @@ +import numpy as np +import cv2 +import os + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img 
= cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img diff --git a/src/flux/annotator/zoe/LICENSE b/src/flux/annotator/zoe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a1e90d007836c327846ce8e5151013b115042ab --- /dev/null +++ b/src/flux/annotator/zoe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intelligent Systems Lab Org + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/src/flux/annotator/zoe/__init__.py b/src/flux/annotator/zoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7628090932e35bdd71d041069ae62f6a731f60d4 --- /dev/null +++ b/src/flux/annotator/zoe/__init__.py @@ -0,0 +1,48 @@ +# ZoeDepth +# https://github.com/isl-org/ZoeDepth + +import os +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth +from .zoedepth.utils.config import get_config +from ...annotator.util import annotator_ckpts_path +from huggingface_hub import hf_hub_download + + +class ZoeDetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt") + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "ZoeD_M12_N.pt") + conf = get_config("zoedepth", "infer") + model = ZoeDepth.build_from_config(conf) + model.load_state_dict(torch.load(model_path)['model'], strict=False) + model = model.cuda() + model.device = 'cuda' + model.eval() + self.model = model + + def __call__(self, input_image): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + return depth_image diff --git a/src/flux/annotator/zoe/zoedepth/data/__init__.py b/src/flux/annotator/zoe/zoedepth/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, 
free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/src/flux/annotator/zoe/zoedepth/data/data_mono.py b/src/flux/annotator/zoe/zoedepth/data/data_mono.py new file mode 100644 index 0000000000000000000000000000000000000000..80a8486f239a35331df553f490e213f9bf71e735 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/data_mono.py @@ -0,0 +1,573 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee + +import itertools +import os +import random + +import numpy as np +import cv2 +import torch +import torch.nn as nn +import torch.utils.data.distributed +from zoedepth.utils.easydict import EasyDict as edict +from PIL import Image, ImageOps +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + +from zoedepth.utils.config import change_dataset + +from .ddad import get_ddad_loader +from .diml_indoor_test import get_diml_indoor_loader +from .diml_outdoor_test import get_diml_outdoor_loader +from .diode import get_diode_loader +from .hypersim import get_hypersim_loader +from .ibims import get_ibims_loader +from .sun_rgbd_loader import get_sunrgbd_loader +from .vkitti import get_vkitti_loader +from .vkitti2 import get_vkitti2_loader + +from .preprocess import CropParams, get_white_border, get_black_border + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def preprocessing_transforms(mode, **kwargs): + return transforms.Compose([ + ToTensor(mode=mode, **kwargs) + ]) + + +class DepthDataLoader(object): + def __init__(self, config, mode, device='cpu', transform=None, **kwargs): + """ + Data loader for depth datasets + + Args: + config (dict): Config dictionary. Refer to utils/config.py + mode (str): "train" or "online_eval" + device (str, optional): Device to load the data on. Defaults to 'cpu'. + transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None. + """ + + self.config = config + + if config.dataset == 'ibims': + self.data = get_ibims_loader(config, batch_size=1, num_workers=1) + return + + if config.dataset == 'sunrgbd': + self.data = get_sunrgbd_loader( + data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'diml_indoor': + self.data = get_diml_indoor_loader( + data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'diml_outdoor': + self.data = get_diml_outdoor_loader( + data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1) + return + + if "diode" in config.dataset: + self.data = get_diode_loader( + config[config.dataset+"_root"], batch_size=1, num_workers=1) + return + + if config.dataset == 'hypersim_test': + self.data = get_hypersim_loader( + config.hypersim_test_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'vkitti': + self.data = get_vkitti_loader( + config.vkitti_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'vkitti2': + self.data = get_vkitti2_loader( + config.vkitti2_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'ddad': + self.data = get_ddad_loader(config.ddad_root, resize_shape=( + 352, 1216), batch_size=1, num_workers=1) + return + + img_size = self.config.get("img_size", None) + img_size = img_size if self.config.get( + "do_input_resize", False) else None + + if transform is None: + transform = preprocessing_transforms(mode, size=img_size) + + if mode == 'train': + + Dataset = DataLoadPreprocess + self.training_samples = Dataset( + config, mode, transform=transform, device=device) + + if config.distributed: + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_samples) + else: + self.train_sampler = None + + self.data 
= DataLoader(self.training_samples, + batch_size=config.batch_size, + shuffle=(self.train_sampler is None), + num_workers=config.workers, + pin_memory=True, + persistent_workers=True, + # prefetch_factor=2, + sampler=self.train_sampler) + + elif mode == 'online_eval': + self.testing_samples = DataLoadPreprocess( + config, mode, transform=transform) + if config.distributed: # redundant. here only for readability and to be more explicit + # Give whole test set to all processes (and report evaluation only on one) regardless + self.eval_sampler = None + else: + self.eval_sampler = None + self.data = DataLoader(self.testing_samples, 1, + shuffle=kwargs.get("shuffle_test", False), + num_workers=1, + pin_memory=False, + sampler=self.eval_sampler) + + elif mode == 'test': + self.testing_samples = DataLoadPreprocess( + config, mode, transform=transform) + self.data = DataLoader(self.testing_samples, + 1, shuffle=False, num_workers=1) + + else: + print( + 'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode)) + + +def repetitive_roundrobin(*iterables): + """ + cycles through iterables but sample wise + first yield first sample from first iterable then first sample from second iterable and so on + then second sample from first iterable then second sample from second iterable and so on + + If one iterable is shorter than the others, it is repeated until all iterables are exhausted + repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E + """ + # Repetitive roundrobin + iterables_ = [iter(it) for it in iterables] + exhausted = [False] * len(iterables) + while not all(exhausted): + for i, it in enumerate(iterables_): + try: + yield next(it) + except StopIteration: + exhausted[i] = True + iterables_[i] = itertools.cycle(iterables[i]) + # First elements may get repeated if one iterable is shorter than the others + yield next(iterables_[i]) + + +class RepetitiveRoundRobinDataLoader(object): + def __init__(self, *dataloaders): + self.dataloaders = dataloaders + + def __iter__(self): + return repetitive_roundrobin(*self.dataloaders) + + def __len__(self): + # First samples get repeated, thats why the plus one + return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1) + + +class MixedNYUKITTI(object): + def __init__(self, config, mode, device='cpu', **kwargs): + config = edict(config) + config.workers = config.workers // 2 + self.config = config + nyu_conf = change_dataset(edict(config), 'nyu') + kitti_conf = change_dataset(edict(config), 'kitti') + + # make nyu default for testing + self.config = config = nyu_conf + img_size = self.config.get("img_size", None) + img_size = img_size if self.config.get( + "do_input_resize", False) else None + if mode == 'train': + nyu_loader = DepthDataLoader( + nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data + kitti_loader = DepthDataLoader( + kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data + # It has been changed to repetitive roundrobin + self.data = RepetitiveRoundRobinDataLoader( + nyu_loader, kitti_loader) + else: + self.data = DepthDataLoader(nyu_conf, mode, device=device).data + + +def remove_leading_slash(s): + if s[0] == '/' or s[0] == '\\': + return s[1:] + return s + + +class CachedReader: + def __init__(self, shared_dict=None): + if shared_dict: + self._cache = shared_dict + else: + self._cache = {} + + def open(self, fpath): + im = self._cache.get(fpath, None) + if im is None: + im = self._cache[fpath] = Image.open(fpath) + 
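+        # later calls with the same path return the cached PIL handle instead of touching
+        # the filesystem again; the cache may be a shared dict so workers can reuse it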
return im + + +class ImReader: + def __init__(self): + pass + + # @cache + def open(self, fpath): + return Image.open(fpath) + + +class DataLoadPreprocess(Dataset): + def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs): + self.config = config + if mode == 'online_eval': + with open(config.filenames_file_eval, 'r') as f: + self.filenames = f.readlines() + else: + with open(config.filenames_file, 'r') as f: + self.filenames = f.readlines() + + self.mode = mode + self.transform = transform + self.to_tensor = ToTensor(mode) + self.is_for_online_eval = is_for_online_eval + if config.use_shared_dict: + self.reader = CachedReader(config.shared_dict) + else: + self.reader = ImReader() + + def postprocess(self, sample): + return sample + + def __getitem__(self, idx): + sample_path = self.filenames[idx] + focal = float(sample_path.split()[2]) + sample = {} + + if self.mode == 'train': + if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5: + image_path = os.path.join( + self.config.data_path, remove_leading_slash(sample_path.split()[3])) + depth_path = os.path.join( + self.config.gt_path, remove_leading_slash(sample_path.split()[4])) + else: + image_path = os.path.join( + self.config.data_path, remove_leading_slash(sample_path.split()[0])) + depth_path = os.path.join( + self.config.gt_path, remove_leading_slash(sample_path.split()[1])) + + image = self.reader.open(image_path) + depth_gt = self.reader.open(depth_path) + w, h = image.size + + if self.config.do_kb_crop: + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth_gt = depth_gt.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + + # Avoid blank boundaries due to pixel registration? + # Train images have white border. Test images have black border. 
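+            # The avoid_boundary branch below crops away the registration border (located via
+            # get_white_border) and then reflect-pads the image back to its original size, so
+            # training never sees constant-colour border pixels; the depth map is zero-padded
+            # instead, which (for min_depth > 0) keeps the padded pixels out of the valid mask.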
+ if self.config.dataset == 'nyu' and self.config.avoid_boundary: + # print("Avoiding Blank Boundaries!") + # We just crop and pad again with reflect padding to original size + # original_size = image.size + crop_params = get_white_border(np.array(image, dtype=np.uint8)) + image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom)) + depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom)) + + # Use reflect padding to fill the blank + image = np.array(image) + image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect') + image = Image.fromarray(image) + + depth_gt = np.array(depth_gt) + depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0) + depth_gt = Image.fromarray(depth_gt) + + + if self.config.do_random_rotate and (self.config.aug): + random_angle = (random.random() - 0.5) * 2 * self.config.degree + image = self.rotate_image(image, random_angle) + depth_gt = self.rotate_image( + depth_gt, random_angle, flag=Image.NEAREST) + + image = np.asarray(image, dtype=np.float32) / 255.0 + depth_gt = np.asarray(depth_gt, dtype=np.float32) + depth_gt = np.expand_dims(depth_gt, axis=2) + + if self.config.dataset == 'nyu': + depth_gt = depth_gt / 1000.0 + else: + depth_gt = depth_gt / 256.0 + + if self.config.aug and (self.config.random_crop): + image, depth_gt = self.random_crop( + image, depth_gt, self.config.input_height, self.config.input_width) + + if self.config.aug and self.config.random_translate: + # print("Random Translation!") + image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation) + + image, depth_gt = self.train_preprocess(image, depth_gt) + mask = np.logical_and(depth_gt > self.config.min_depth, + depth_gt < self.config.max_depth).squeeze()[None, ...] + sample = {'image': image, 'depth': depth_gt, 'focal': focal, + 'mask': mask, **sample} + + else: + if self.mode == 'online_eval': + data_path = self.config.data_path_eval + else: + data_path = self.config.data_path + + image_path = os.path.join( + data_path, remove_leading_slash(sample_path.split()[0])) + image = np.asarray(self.reader.open(image_path), + dtype=np.float32) / 255.0 + + if self.mode == 'online_eval': + gt_path = self.config.gt_path_eval + depth_path = os.path.join( + gt_path, remove_leading_slash(sample_path.split()[1])) + has_valid_depth = False + try: + depth_gt = self.reader.open(depth_path) + has_valid_depth = True + except IOError: + depth_gt = False + # print('Missing gt for {}'.format(image_path)) + + if has_valid_depth: + depth_gt = np.asarray(depth_gt, dtype=np.float32) + depth_gt = np.expand_dims(depth_gt, axis=2) + if self.config.dataset == 'nyu': + depth_gt = depth_gt / 1000.0 + else: + depth_gt = depth_gt / 256.0 + + mask = np.logical_and( + depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...] 
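+                    # (the mask keeps only pixels whose ground-truth depth lies inside
+                    # [min_depth, max_depth]; when no valid depth file exists, no mask is built)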
+ else: + mask = False + + if self.config.do_kb_crop: + height = image.shape[0] + width = image.shape[1] + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + image = image[top_margin:top_margin + 352, + left_margin:left_margin + 1216, :] + if self.mode == 'online_eval' and has_valid_depth: + depth_gt = depth_gt[top_margin:top_margin + + 352, left_margin:left_margin + 1216, :] + + if self.mode == 'online_eval': + sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth, + 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1], + 'mask': mask} + else: + sample = {'image': image, 'focal': focal} + + if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']): + mask = np.logical_and(depth_gt > self.config.min_depth, + depth_gt < self.config.max_depth).squeeze()[None, ...] + sample['mask'] = mask + + if self.transform: + sample = self.transform(sample) + + sample = self.postprocess(sample) + sample['dataset'] = self.config.dataset + sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]} + + return sample + + def rotate_image(self, image, angle, flag=Image.BILINEAR): + result = image.rotate(angle, resample=flag) + return result + + def random_crop(self, img, depth, height, width): + assert img.shape[0] >= height + assert img.shape[1] >= width + assert img.shape[0] == depth.shape[0] + assert img.shape[1] == depth.shape[1] + x = random.randint(0, img.shape[1] - width) + y = random.randint(0, img.shape[0] - height) + img = img[y:y + height, x:x + width, :] + depth = depth[y:y + height, x:x + width, :] + + return img, depth + + def random_translate(self, img, depth, max_t=20): + assert img.shape[0] == depth.shape[0] + assert img.shape[1] == depth.shape[1] + p = self.config.translate_prob + do_translate = random.random() + if do_translate > p: + return img, depth + x = random.randint(-max_t, max_t) + y = random.randint(-max_t, max_t) + M = np.float32([[1, 0, x], [0, 1, y]]) + # print(img.shape, depth.shape) + img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0])) + depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0])) + depth = depth.squeeze()[..., None] # add channel dim back. 
Affine warp removes it + # print("after", img.shape, depth.shape) + return img, depth + + def train_preprocess(self, image, depth_gt): + if self.config.aug: + # Random flipping + do_flip = random.random() + if do_flip > 0.5: + image = (image[:, ::-1, :]).copy() + depth_gt = (depth_gt[:, ::-1, :]).copy() + + # Random gamma, brightness, color augmentation + do_augment = random.random() + if do_augment > 0.5: + image = self.augment_image(image) + + return image, depth_gt + + def augment_image(self, image): + # gamma augmentation + gamma = random.uniform(0.9, 1.1) + image_aug = image ** gamma + + # brightness augmentation + if self.config.dataset == 'nyu': + brightness = random.uniform(0.75, 1.25) + else: + brightness = random.uniform(0.9, 1.1) + image_aug = image_aug * brightness + + # color augmentation + colors = np.random.uniform(0.9, 1.1, size=3) + white = np.ones((image.shape[0], image.shape[1])) + color_image = np.stack([white * colors[i] for i in range(3)], axis=2) + image_aug *= color_image + image_aug = np.clip(image_aug, 0, 1) + + return image_aug + + def __len__(self): + return len(self.filenames) + + +class ToTensor(object): + def __init__(self, mode, do_normalize=False, size=None): + self.mode = mode + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity() + self.size = size + if size is not None: + self.resize = transforms.Resize(size=size) + else: + self.resize = nn.Identity() + + def __call__(self, sample): + image, focal = sample['image'], sample['focal'] + image = self.to_tensor(image) + image = self.normalize(image) + image = self.resize(image) + + if self.mode == 'test': + return {'image': image, 'focal': focal} + + depth = sample['depth'] + if self.mode == 'train': + depth = self.to_tensor(depth) + return {**sample, 'image': image, 'depth': depth, 'focal': focal} + else: + has_valid_depth = sample['has_valid_depth'] + image = self.resize(image) + return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth, + 'image_path': sample['image_path'], 'depth_path': sample['depth_path']} + + def to_tensor(self, pic): + if not (_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError( + 'pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img diff --git a/src/flux/annotator/zoe/zoedepth/data/ddad.py b/src/flux/annotator/zoe/zoedepth/data/ddad.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd0492bdec767685d3a21992b4a26e62d002d97 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/ddad.py @@ -0,0 +1,117 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self, resize_shape): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize(resize_shape) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "ddad"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DDAD(Dataset): + def __init__(self, data_dir_root, resize_shape): + import glob + + # image paths are of the form /{outleft, depthmap}/*.png + self.image_files = glob.glob(os.path.join(data_dir_root, '*.png')) + self.depth_files = [r.replace("_rgb.png", "_depth.npy") + for r in self.image_files] + self.transform = ToTensor(resize_shape) + + def __getitem__(self, idx): + + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.load(depth_path) # meters + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs): + dataset = DDAD(data_dir_root, resize_shape) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py b/src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f720ad9aefaee78ef4ec363dfef0f82ace850a6d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py @@ -0,0 +1,125 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize((480, 640)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "diml_indoor"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIML_Indoor(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /{HR, LR}//{color, depth_filled}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "LR", '*', 'color', '*.png')) + self.depth_files = [r.replace("color", "depth_filled").replace( + "_c.png", "_depth_filled.png") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype='uint16') / 1000.0 # mm to meters + + # print(np.shape(image)) + # print(np.shape(depth)) + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIML_Indoor(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR") +# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR") diff --git a/src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py b/src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8670b48f5febafb819dac22848ad79ccb5dd5ae4 --- /dev/null +++ 
b/src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py @@ -0,0 +1,114 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIML_Outdoor(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /{outleft, depthmap}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "*", 'outleft', '*.png')) + self.depth_files = [r.replace("outleft", "depthmap") + for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype='uint16') / 1000.0 # mm to meters + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth, dataset="diml_outdoor") + + # return sample + return self.transform(sample) + + def __len__(self): + return len(self.image_files) + + +def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs): + dataset = 
DIML_Outdoor(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR") +# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR") diff --git a/src/flux/annotator/zoe/zoedepth/data/diode.py b/src/flux/annotator/zoe/zoedepth/data/diode.py new file mode 100644 index 0000000000000000000000000000000000000000..1510c87116b8f70ce2e1428873a8e4da042bee23 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/diode.py @@ -0,0 +1,125 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize(480) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "diode"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIODE(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /scene_#/scan_#/*.png + self.image_files = glob.glob( + os.path.join(data_dir_root, '*', '*', '*.png')) + self.depth_files = [r.replace(".png", "_depth.npy") + for r in self.image_files] + self.depth_mask_files = [ + r.replace(".png", "_depth_mask.npy") for r in self.image_files] + self.transform = 
ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + depth_mask_path = self.depth_mask_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.load(depth_path) # in meters + valid = np.load(depth_mask_path) # binary + + # depth[depth > 8] = -1 + # depth = depth[..., None] + + sample = dict(image=image, depth=depth, valid=valid) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_diode_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIODE(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diode_loader(data_dir_root="datasets/diode/val/outdoor") diff --git a/src/flux/annotator/zoe/zoedepth/data/hypersim.py b/src/flux/annotator/zoe/zoedepth/data/hypersim.py new file mode 100644 index 0000000000000000000000000000000000000000..4334198971830200f72ea2910d03f4c7d6a43334 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/hypersim.py @@ -0,0 +1,138 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
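Editor's note: the DIODE loader above reads metric depth and a validity mask from .npy files, but its ToTensor returns only 'image', 'depth' and 'dataset', so the 'valid' mask never reaches the batched sample. A small sketch of applying the mask by hand, with hypothetical file names following the loader's <name>.png / <name>_depth.npy / <name>_depth_mask.npy convention.

import numpy as np

depth = np.squeeze(np.load("scan/frame_depth.npy"))       # metric depth in meters
valid = np.squeeze(np.load("scan/frame_depth_mask.npy"))  # nonzero where depth is valid
masked = np.where(valid > 0, depth, -1.0)  # mirror the -1 "invalid" convention used by the other loaders
print(masked.shape, float(masked[valid > 0].mean()))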
+ +# File author: Shariq Farooq Bhat + +import glob +import os + +import h5py +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +def hypersim_distance_to_depth(npyDistance): + intWidth, intHeight, fltFocal = 1024, 768, 886.81 + + npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape( + 1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None] + npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5, + intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None] + npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32) + npyImageplane = np.concatenate( + [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2) + + npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal + return npyDepth + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x: x + self.resize = transforms.Resize((480, 640)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "hypersim"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class HyperSim(Dataset): + def __init__(self, data_dir_root): + # image paths are of the form //images/scene_cam_#_final_preview/*.tonemap.jpg + # depth paths are of the form //images/scene_cam_#_final_preview/*.depth_meters.hdf5 + self.image_files = glob.glob(os.path.join( + data_dir_root, '*', 'images', 'scene_cam_*_final_preview', '*.tonemap.jpg')) + self.depth_files = [r.replace("_final_preview", "_geometry_hdf5").replace( + ".tonemap.jpg", ".depth_meters.hdf5") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + + # depth from hdf5 + depth_fd = h5py.File(depth_path, "r") + # in meters (Euclidean distance) + distance_meters = np.array(depth_fd['dataset']) + depth = hypersim_distance_to_depth( + distance_meters) # in meters (planar depth) + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_hypersim_loader(data_dir_root, batch_size=1, **kwargs): + dataset = HyperSim(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) 
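Editor's note: hypersim_distance_to_depth converts Hypersim's per-pixel ray distance to planar depth as depth = distance * f / sqrt(x^2 + y^2 + f^2), with f = 886.81 on a 1024x768 image plane. A quick numerical sanity check (illustrative values; the import path is an assumption):

import numpy as np
from zoedepth.data.hypersim import hypersim_distance_to_depth  # assumed import path

dist = np.full((768, 1024), 10.0, dtype=np.float32)  # constant 10 m ray distance
depth = hypersim_distance_to_depth(dist)
print(round(float(depth[384, 512]), 2))  # ~10.0 near the principal point
print(round(float(depth[0, 0]), 2))      # ~8.1 in the corner, where the ray is oblique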
diff --git a/src/flux/annotator/zoe/zoedepth/data/ibims.py b/src/flux/annotator/zoe/zoedepth/data/ibims.py new file mode 100644 index 0000000000000000000000000000000000000000..b66abfabcf4cfc617d4a60ec818780c3548d9920 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/ibims.py @@ -0,0 +1,81 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms as T + + +class iBims(Dataset): + def __init__(self, config): + root_folder = config.ibims_root + with open(os.path.join(root_folder, "imagelist.txt"), 'r') as f: + imglist = f.read().split() + + samples = [] + for basename in imglist: + img_path = os.path.join(root_folder, 'rgb', basename + ".png") + depth_path = os.path.join(root_folder, 'depth', basename + ".png") + valid_mask_path = os.path.join( + root_folder, 'mask_invalid', basename+".png") + transp_mask_path = os.path.join( + root_folder, 'mask_transp', basename+".png") + + samples.append( + (img_path, depth_path, valid_mask_path, transp_mask_path)) + + self.samples = samples + # self.normalize = T.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __getitem__(self, idx): + img_path, depth_path, valid_mask_path, transp_mask_path = self.samples[idx] + + img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype=np.uint16).astype('float')*50.0/65535 + + mask_valid = np.asarray(Image.open(valid_mask_path)) + mask_transp = np.asarray(Image.open(transp_mask_path)) + + # depth = depth * mask_valid * mask_transp + depth = np.where(mask_valid * mask_transp, depth, -1) + + img = torch.from_numpy(img).permute(2, 0, 1) + img = self.normalize(img) + depth = torch.from_numpy(depth).unsqueeze(0) + return dict(image=img, depth=depth, image_path=img_path, depth_path=depth_path, dataset='ibims') + + def __len__(self): + return len(self.samples) + + +def get_ibims_loader(config, batch_size=1, **kwargs): + dataloader = DataLoader(iBims(config), batch_size=batch_size, **kwargs) + return dataloader diff --git a/src/flux/annotator/zoe/zoedepth/data/preprocess.py b/src/flux/annotator/zoe/zoedepth/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e08cc309dc823ae6efd7cda8db9eb37130dc5499 --- /dev/null 
+++ b/src/flux/annotator/zoe/zoedepth/data/preprocess.py @@ -0,0 +1,154 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +from dataclasses import dataclass +from typing import Tuple, List + +# dataclass to store the crop parameters +@dataclass +class CropParams: + top: int + bottom: int + left: int + right: int + + + +def get_border_params(rgb_image, tolerance=0.1, cut_off=20, value=0, level_diff_threshold=5, channel_axis=-1, min_border=5) -> CropParams: + gray_image = np.mean(rgb_image, axis=channel_axis) + h, w = gray_image.shape + + + def num_value_pixels(arr): + return np.sum(np.abs(arr - value) < level_diff_threshold) + + def is_above_tolerance(arr, total_pixels): + return (num_value_pixels(arr) / total_pixels) > tolerance + + # Crop top border until number of value pixels become below tolerance + top = min_border + while is_above_tolerance(gray_image[top, :], w) and top < h-1: + top += 1 + if top > cut_off: + break + + # Crop bottom border until number of value pixels become below tolerance + bottom = h - min_border + while is_above_tolerance(gray_image[bottom, :], w) and bottom > 0: + bottom -= 1 + if h - bottom > cut_off: + break + + # Crop left border until number of value pixels become below tolerance + left = min_border + while is_above_tolerance(gray_image[:, left], h) and left < w-1: + left += 1 + if left > cut_off: + break + + # Crop right border until number of value pixels become below tolerance + right = w - min_border + while is_above_tolerance(gray_image[:, right], h) and right > 0: + right -= 1 + if w - right > cut_off: + break + + + return CropParams(top, bottom, left, right) + + +def get_white_border(rgb_image, value=255, **kwargs) -> CropParams: + """Crops the white border of the RGB. + + Args: + rgb: RGB image, shape (H, W, 3). + Returns: + Crop parameters. + """ + if value == 255: + # assert range of values in rgb image is [0, 255] + assert np.max(rgb_image) <= 255 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 255]." + assert rgb_image.max() > 1, "RGB image values are not in range [0, 255]." + elif value == 1: + # assert range of values in rgb image is [0, 1] + assert np.max(rgb_image) <= 1 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 1]." 
+ + return get_border_params(rgb_image, value=value, **kwargs) + +def get_black_border(rgb_image, **kwargs) -> CropParams: + """Crops the black border of the RGB. + + Args: + rgb: RGB image, shape (H, W, 3). + + Returns: + Crop parameters. + """ + + return get_border_params(rgb_image, value=0, **kwargs) + +def crop_image(image: np.ndarray, crop_params: CropParams) -> np.ndarray: + """Crops the image according to the crop parameters. + + Args: + image: RGB or depth image, shape (H, W, 3) or (H, W). + crop_params: Crop parameters. + + Returns: + Cropped image. + """ + return image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right] + +def crop_images(*images: np.ndarray, crop_params: CropParams) -> Tuple[np.ndarray]: + """Crops the images according to the crop parameters. + + Args: + images: RGB or depth images, shape (H, W, 3) or (H, W). + crop_params: Crop parameters. + + Returns: + Cropped images. + """ + return tuple(crop_image(image, crop_params) for image in images) + +def crop_black_or_white_border(rgb_image, *other_images: np.ndarray, tolerance=0.1, cut_off=20, level_diff_threshold=5) -> Tuple[np.ndarray]: + """Crops the white and black border of the RGB and depth images. + + Args: + rgb: RGB image, shape (H, W, 3). This image is used to determine the border. + other_images: The other images to crop according to the border of the RGB image. + Returns: + Cropped RGB and other images. + """ + # crop black border + crop_params = get_black_border(rgb_image, tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold) + cropped_images = crop_images(rgb_image, *other_images, crop_params=crop_params) + + # crop white border + crop_params = get_white_border(cropped_images[0], tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold) + cropped_images = crop_images(*cropped_images, crop_params=crop_params) + + return cropped_images + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py b/src/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2bdb9aefe68ca4439f41eff3bba722c49fb976 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py @@ -0,0 +1,106 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
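Editor's note: crop_black_or_white_border above trims a black frame first, then a white one; the white pass asserts pixel values in the 0-255 range, and each pass removes at least min_border and at most roughly cut_off rows/columns per side. A synthetic sketch (array contents and import path are assumptions):

import numpy as np
from zoedepth.data.preprocess import crop_black_or_white_border  # assumed import path

rgb = np.zeros((480, 640, 3), dtype=np.float32)  # black frame...
rgb[40:440, 40:600] = 128.0                      # ...around mid-gray content, on a 0-255 scale
depth = np.ones((480, 640), dtype=np.float32)

rgb_c, depth_c = crop_black_or_white_border(rgb, depth)
print(rgb.shape, "->", rgb_c.shape)  # the same crop window is applied to every passed image
print(depth_c.shape)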
+ +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + return {'image': image, 'depth': depth, 'dataset': "sunrgbd"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class SunRGBD(Dataset): + def __init__(self, data_dir_root): + # test_file_dirs = loadmat(train_test_file)['alltest'].squeeze() + # all_test = [t[0].replace("/n/fs/sun3d/data/", "") for t in test_file_dirs] + # self.all_test = [os.path.join(data_dir_root, t) for t in all_test] + import glob + self.image_files = glob.glob( + os.path.join(data_dir_root, 'rgb', 'rgb', '*')) + self.depth_files = [ + r.replace("rgb/rgb", "gt/gt").replace("jpg", "png") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), dtype='uint16') / 1000.0 + depth[depth > 8] = -1 + depth = depth[..., None] + return self.transform(dict(image=image, depth=depth)) + + def __len__(self): + return len(self.image_files) + + +def get_sunrgbd_loader(data_dir_root, batch_size=1, **kwargs): + dataset = SunRGBD(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/src/flux/annotator/zoe/zoedepth/data/transforms.py b/src/flux/annotator/zoe/zoedepth/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..374416dff24fb4fd55598f3946d6d6b091ddefc9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/transforms.py @@ -0,0 +1,481 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
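Editor's note: usage sketch for the SUN RGB-D loader above. The root path is hypothetical; the loader expects RGB images under <root>/rgb/rgb/ with ground-truth PNGs under <root>/gt/gt/, converts the 16-bit depth to meters by dividing by 1000, and flags depths beyond 8 m with -1.

from zoedepth.data.sun_rgbd_loader import get_sunrgbd_loader  # assumed import path

loader = get_sunrgbd_loader("datasets/sunrgbd_test", batch_size=1, num_workers=2)
sample = next(iter(loader))
image, depth = sample["image"], sample["depth"]
print(image.shape, depth.shape)
print(float(depth.max()))  # <= 8.0, since farther pixels are set to -1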
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import math +import random + +import cv2 +import numpy as np + + +class RandomFliplr(object): + """Horizontal flip of the sample with given probability. + """ + + def __init__(self, probability=0.5): + """Init. + + Args: + probability (float, optional): Flip probability. Defaults to 0.5. + """ + self.__probability = probability + + def __call__(self, sample): + prob = random.random() + + if prob < self.__probability: + for k, v in sample.items(): + if len(v.shape) >= 2: + sample[k] = np.fliplr(v).copy() + + return sample + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class RandomCrop(object): + """Get a random crop of the sample with the given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_if_needed=False, + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): output width + height (int): output height + resize_if_needed (bool, optional): If True, sample might be upsampled to ensure + that a crop of size (width, height) is possbile. Defaults to False. + """ + self.__size = (height, width) + self.__resize_if_needed = resize_if_needed + self.__image_interpolation_method = image_interpolation_method + + def __call__(self, sample): + + shape = sample["disparity"].shape + + if self.__size[0] > shape[0] or self.__size[1] > shape[1]: + if self.__resize_if_needed: + shape = apply_min_size( + sample, self.__size, self.__image_interpolation_method + ) + else: + raise Exception( + "Output size {} bigger than input size {}.".format( + self.__size, shape + ) + ) + + offset = ( + np.random.randint(shape[0] - self.__size[0] + 1), + np.random.randint(shape[1] - self.__size[1] + 1), + ) + + for k, v in sample.items(): + if k == "code" or k == "basis": + continue + + if len(sample[k].shape) >= 2: + sample[k] = v[ + offset[0]: offset[0] + self.__size[0], + offset[1]: offset[1] + self.__size[1], + ] + + return sample + + +class Resize(object): + """Resize sample to given size (width, height). 
+ """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + letter_box=False, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + self.__letter_box = letter_box + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + 
def make_letter_box(self, sample): + top = bottom = (self.__height - sample.shape[0]) // 2 + left = right = (self.__width - sample.shape[1]) // 2 + sample = cv2.copyMakeBorder( + sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0) + return sample + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__letter_box: + sample["image"] = self.make_letter_box(sample["image"]) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if self.__letter_box: + sample["disparity"] = self.make_letter_box( + sample["disparity"]) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, + height), interpolation=cv2.INTER_NEAREST + ) + + if self.__letter_box: + sample["depth"] = self.make_letter_box(sample["depth"]) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if self.__letter_box: + sample["mask"] = self.make_letter_box(sample["mask"]) + + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class ResizeFixed(object): + def __init__(self, size): + self.__size = size + + def __call__(self, sample): + sample["image"] = cv2.resize( + sample["image"], self.__size[::-1], interpolation=cv2.INTER_LINEAR + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], self.__size[::- + 1], interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + self.__size[::-1], + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class Rescale(object): + """Rescale target values to the interval [0, max_val]. + If input is constant, values are set to max_val / 2. + """ + + def __init__(self, max_val=1.0, use_mask=True): + """Init. + + Args: + max_val (float, optional): Max output value. Defaults to 1.0. + use_mask (bool, optional): Only operate on valid pixels (mask == True). Defaults to True. + """ + self.__max_val = max_val + self.__use_mask = use_mask + + def __call__(self, sample): + disp = sample["disparity"] + + if self.__use_mask: + mask = sample["mask"] + else: + mask = np.ones_like(disp, dtype=np.bool) + + if np.sum(mask) == 0: + return sample + + min_val = np.min(disp[mask]) + max_val = np.max(disp[mask]) + + if max_val > min_val: + sample["disparity"][mask] = ( + (disp[mask] - min_val) / (max_val - min_val) * self.__max_val + ) + else: + sample["disparity"][mask] = np.ones_like( + disp[mask]) * self.__max_val / 2.0 + + return sample + + +# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class DepthToDisparity(object): + """Convert depth to disparity. Removes depth from sample. 
+ """ + + def __init__(self, eps=1e-4): + self.__eps = eps + + def __call__(self, sample): + assert "depth" in sample + + sample["mask"][sample["depth"] < self.__eps] = False + + sample["disparity"] = np.zeros_like(sample["depth"]) + sample["disparity"][sample["depth"] >= self.__eps] = ( + 1.0 / sample["depth"][sample["depth"] >= self.__eps] + ) + + del sample["depth"] + + return sample + + +class DisparityToDepth(object): + """Convert disparity to depth. Removes disparity from sample. + """ + + def __init__(self, eps=1e-4): + self.__eps = eps + + def __call__(self, sample): + assert "disparity" in sample + + disp = np.abs(sample["disparity"]) + sample["mask"][disp < self.__eps] = False + + # print(sample["disparity"]) + # print(sample["mask"].sum()) + # exit() + + sample["depth"] = np.zeros_like(disp) + sample["depth"][disp >= self.__eps] = ( + 1.0 / disp[disp >= self.__eps] + ) + + del sample["disparity"] + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/src/flux/annotator/zoe/zoedepth/data/vkitti.py b/src/flux/annotator/zoe/zoedepth/data/vkitti.py new file mode 100644 index 0000000000000000000000000000000000000000..72a2e5a8346f6e630ede0e28d6959725af8d7e72 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/vkitti.py @@ -0,0 +1,151 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +import os + +from PIL import Image +import numpy as np +import cv2 + + +class ToTensor(object): + def __init__(self): + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + # self.resize = transforms.Resize((375, 1242)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + # image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "vkitti"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class VKITTI(Dataset): + def __init__(self, data_dir_root, do_kb_crop=True): + import glob + # image paths are of the form /{HR, LR}//{color, depth_filled}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "test_color", '*.png')) + self.depth_files = [r.replace("test_color", "test_depth") + for r in self.image_files] + self.do_kb_crop = True + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = Image.open(image_path) + depth = Image.open(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | + cv2.IMREAD_ANYDEPTH) + print("dpeth min max", depth.min(), depth.max()) + + # print(np.shape(image)) + # print(np.shape(depth)) + + # depth[depth > 8] = -1 + + if self.do_kb_crop and False: + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth = depth.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216] + + image = np.asarray(image, dtype=np.float32) / 255.0 + # depth = np.asarray(depth, dtype=np.uint16) /1. 
+ depth = depth[..., None] + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_vkitti_loader(data_dir_root, batch_size=1, **kwargs): + dataset = VKITTI(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + + +if __name__ == "__main__": + loader = get_vkitti_loader( + data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti_test") + print("Total files", len(loader.dataset)) + for i, sample in enumerate(loader): + print(sample["image"].shape) + print(sample["depth"].shape) + print(sample["dataset"]) + print(sample['depth'].min(), sample['depth'].max()) + if i > 5: + break diff --git a/src/flux/annotator/zoe/zoedepth/data/vkitti2.py b/src/flux/annotator/zoe/zoedepth/data/vkitti2.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcfb0414b7f3f21859f30ae34bd71689516a3e7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/vkitti2.py @@ -0,0 +1,187 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
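Editor's note: both virtual-KITTI loaders in this patch implement the "KB crop" used for KITTI evaluation: a 352x1216 window anchored to the bottom of the frame and centered horizontally. In vkitti.py above the branch is short-circuited (the condition reads "if self.do_kb_crop and False"), while the VKITTI2 loader added next applies it unconditionally. A standalone sketch of the arithmetic, assuming a typical 1242x375 frame:

from PIL import Image

def kb_crop(img: Image.Image) -> Image.Image:
    """Crop a 352x1216 bottom-centered window, as in the loaders above."""
    top = img.height - 352
    left = (img.width - 1216) // 2
    return img.crop((left, top, left + 1216, top + 352))

frame = Image.new("RGB", (1242, 375))  # W x H, typical (V)KITTI resolution
print(kb_crop(frame).size)             # (1216, 352)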
+ +# File author: Shariq Farooq Bhat + +import os + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x: x + # self.resize = transforms.Resize((375, 1242)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + # image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "vkitti"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class VKITTI2(Dataset): + def __init__(self, data_dir_root, do_kb_crop=True, split="test"): + import glob + + # image paths are of the form /rgb///frames//Camera<0,1>/rgb_{}.jpg + self.image_files = glob.glob(os.path.join( + data_dir_root, "rgb", "**", "frames", "rgb", "Camera_0", '*.jpg'), recursive=True) + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + self.do_kb_crop = True + self.transform = ToTensor() + + # If train test split is not created, then create one. + # Split is such that 8% of the frames from each scene are used for testing. 
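Editor's note: the block that follows builds a per-scene train/test split (92% / 8%) the first time the dataset root is seen and caches it in train.txt/test.txt. One caveat worth double-checking: with paths of the form rgb/<scene>/<variation>/frames/rgb/Camera_0/*.jpg, basename(dirname(dirname(dirname(f)))) appears to resolve to "frames" rather than the scene folder, which would collapse all files into one group (the overall 92/8 ratio still holds, but the split is no longer stratified by scene). A standalone sketch of the intended per-scene split, with hypothetical paths:

import random
from collections import defaultdict

files = [f"Scene{s:02d}/clone/frames/rgb/Camera_0/rgb_{i:05d}.jpg"
         for s in (1, 2) for i in range(100)]

by_scene = defaultdict(list)
for f in files:
    by_scene[f.split("/")[0]].append(f)  # scene id taken from the top-level folder here

random.seed(0)
train, test = [], []
for scene_files in by_scene.values():
    random.shuffle(scene_files)
    cut = int(len(scene_files) * 0.92)   # 92% train / 8% test per scene
    train += scene_files[:cut]
    test += scene_files[cut:]
print(len(train), len(test))             # 184 16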
+ if not os.path.exists(os.path.join(data_dir_root, "train.txt")): + import random + scenes = set([os.path.basename(os.path.dirname( + os.path.dirname(os.path.dirname(f)))) for f in self.image_files]) + train_files = [] + test_files = [] + for scene in scenes: + scene_files = [f for f in self.image_files if os.path.basename( + os.path.dirname(os.path.dirname(os.path.dirname(f)))) == scene] + random.shuffle(scene_files) + train_files.extend(scene_files[:int(len(scene_files) * 0.92)]) + test_files.extend(scene_files[int(len(scene_files) * 0.92):]) + with open(os.path.join(data_dir_root, "train.txt"), "w") as f: + f.write("\n".join(train_files)) + with open(os.path.join(data_dir_root, "test.txt"), "w") as f: + f.write("\n".join(test_files)) + + if split == "train": + with open(os.path.join(data_dir_root, "train.txt"), "r") as f: + self.image_files = f.read().splitlines() + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + elif split == "test": + with open(os.path.join(data_dir_root, "test.txt"), "r") as f: + self.image_files = f.read().splitlines() + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = Image.open(image_path) + # depth = Image.open(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | + cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m + depth = Image.fromarray(depth) + # print("dpeth min max", depth.min(), depth.max()) + + # print(np.shape(image)) + # print(np.shape(depth)) + + if self.do_kb_crop: + if idx == 0: + print("Using KB input crop") + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth = depth.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216] + + image = np.asarray(image, dtype=np.float32) / 255.0 + # depth = np.asarray(depth, dtype=np.uint16) /1. + depth = np.asarray(depth, dtype=np.float32) / 1. 
+ depth[depth > 80] = -1 + + depth = depth[..., None] + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_vkitti2_loader(data_dir_root, batch_size=1, **kwargs): + dataset = VKITTI2(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + + +if __name__ == "__main__": + loader = get_vkitti2_loader( + data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti2") + print("Total files", len(loader.dataset)) + for i, sample in enumerate(loader): + print(sample["image"].shape) + print(sample["depth"].shape) + print(sample["dataset"]) + print(sample['depth'].min(), sample['depth'].max()) + if i > 5: + break diff --git a/src/flux/annotator/zoe/zoedepth/models/__init__.py b/src/flux/annotator/zoe/zoedepth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/__init__.py b/src/flux/annotator/zoe/zoedepth/models/base_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas.py new file mode 100644 index 0000000000000000000000000000000000000000..ee660bc93d44c28efe8d8c674e715ea2ecb4c183 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas.py @@ -0,0 +1,379 @@ +# MIT License +import os + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import numpy as np +from torchvision.transforms import Normalize + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + +def get_activation(name, bank): + def hook(model, input, output): + bank[name] = output + return hook + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + ): + """Init. + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. 
(Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + print("Params passed to Resize transform:") + print("\twidth: ", width) + print("\theight: ", height) + print("\tresize_target: ", resize_target) + print("\tkeep_aspect_ratio: ", keep_aspect_ratio) + print("\tensure_multiple_of: ", ensure_multiple_of) + print("\tresize_method: ", resize_method) + + self.__width = width + self.__height = height + + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, x): + width, height = self.get_size(*x.shape[-2:][::-1]) + return nn.functional.interpolate(x, (height, width), mode='bilinear', align_corners=True) + +class PrepForMidas(object): + def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True): + if isinstance(img_size, int): + img_size = (img_size, img_size) + net_h, net_w = img_size + self.normalization = Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \ + if do_resize else nn.Identity() + + def __call__(self, x): + return self.normalization(self.resizer(x)) + + +class MidasCore(nn.Module): + def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True, + img_size=384, 
**kwargs): + """Midas Base model used for multi-scale feature extraction. + + Args: + midas (torch.nn.Module): Midas model. + trainable (bool, optional): Train midas model. Defaults to False. + fetch_features (bool, optional): Extract multi-scale features. Defaults to True. + layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'). + freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False. + keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True. + img_size (int, tuple, optional): Input resolution. Defaults to 384. + """ + super().__init__() + self.core = midas + self.output_channels = None + self.core_out = {} + self.trainable = trainable + self.fetch_features = fetch_features + # midas.scratch.output_conv = nn.Identity() + self.handles = [] + # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1'] + self.layer_names = layer_names + + self.set_trainable(trainable) + self.set_fetch_features(fetch_features) + + self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio, + img_size=img_size, do_resize=kwargs.get('do_resize', True)) + + if freeze_bn: + self.freeze_bn() + + def set_trainable(self, trainable): + self.trainable = trainable + if trainable: + self.unfreeze() + else: + self.freeze() + return self + + def set_fetch_features(self, fetch_features): + self.fetch_features = fetch_features + if fetch_features: + if len(self.handles) == 0: + self.attach_hooks(self.core) + else: + self.remove_hooks() + return self + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.trainable = False + return self + + def unfreeze(self): + for p in self.parameters(): + p.requires_grad = True + self.trainable = True + return self + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + return self + + def forward(self, x, denorm=False, return_rel_depth=False): + with torch.no_grad(): + if denorm: + x = denormalize(x) + x = self.prep(x) + # print("Shape after prep: ", x.shape) + + with torch.set_grad_enabled(self.trainable): + + # print("Input size to Midascore", x.shape) + rel_depth = self.core(x) + # print("Output from midas shape", rel_depth.shape) + if not self.fetch_features: + return rel_depth + out = [self.core_out[k] for k in self.layer_names] + + if return_rel_depth: + return rel_depth, out + return out + + def get_rel_pos_params(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + def freeze_encoder(self, freeze_rel_pos=False): + if freeze_rel_pos: + for p in self.core.pretrained.parameters(): + p.requires_grad = False + else: + for p in self.get_enc_params_except_rel_pos(): + p.requires_grad = False + return self + + def attach_hooks(self, midas): + if len(self.handles) > 0: + self.remove_hooks() + if "out_conv" in self.layer_names: + self.handles.append(list(midas.scratch.output_conv.children())[ + 3].register_forward_hook(get_activation("out_conv", self.core_out))) + if "r4" in self.layer_names: + self.handles.append(midas.scratch.refinenet4.register_forward_hook( + get_activation("r4", self.core_out))) + if "r3" in self.layer_names: + 
self.handles.append(midas.scratch.refinenet3.register_forward_hook( + get_activation("r3", self.core_out))) + if "r2" in self.layer_names: + self.handles.append(midas.scratch.refinenet2.register_forward_hook( + get_activation("r2", self.core_out))) + if "r1" in self.layer_names: + self.handles.append(midas.scratch.refinenet1.register_forward_hook( + get_activation("r1", self.core_out))) + if "l4_rn" in self.layer_names: + self.handles.append(midas.scratch.layer4_rn.register_forward_hook( + get_activation("l4_rn", self.core_out))) + + return self + + def remove_hooks(self): + for h in self.handles: + h.remove() + return self + + def __del__(self): + self.remove_hooks() + + def set_output_channels(self, model_type): + self.output_channels = MIDAS_SETTINGS[model_type] + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs): + if midas_model_type not in MIDAS_SETTINGS: + raise ValueError( + f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}") + if "img_size" in kwargs: + kwargs = MidasCore.parse_img_size(kwargs) + img_size = kwargs.pop("img_size", [384, 384]) + print("img_size", img_size) + midas_path = os.path.join(os.path.dirname(__file__), 'midas_repo') + midas = torch.hub.load(midas_path, midas_model_type, + pretrained=use_pretrained_midas, force_reload=force_reload, source='local') + kwargs.update({'keep_aspect_ratio': force_keep_ar}) + midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features, + freeze_bn=freeze_bn, img_size=img_size, **kwargs) + midas_core.set_output_channels(midas_model_type) + return midas_core + + @staticmethod + def build_from_config(config): + return MidasCore.build(**config) + + @staticmethod + def parse_img_size(config): + assert 'img_size' in config + if isinstance(config['img_size'], str): + assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W" + config['img_size'] = list(map(int, config['img_size'].split(","))) + assert len( + config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W" + elif isinstance(config['img_size'], int): + config['img_size'] = [config['img_size'], config['img_size']] + else: + assert isinstance(config['img_size'], list) and len( + config['img_size']) == 2, "img_size should be a list of H,W" + return config + + +nchannels2models = { + tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"], + (512, 256, 128, 64, 64): ["MiDaS_small"] +} + +# Model name to number of output channels +MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items() + for m in v + } diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a13c80028de3d297de4a3f09cee1b20759acc006 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore @@ -0,0 +1,110 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are 
written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +*.png +*.pfm +*.jpg +*.jpeg +*.pt \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..466bc94ba3128ea9cbe4bde82bd2fd1fc9daa8af --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile @@ -0,0 +1,29 @@ +# enables cuda support in docker +FROM nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04 + +# install python 3.6, pip and requirements for opencv-python +# (see https://github.com/NVIDIA/nvidia-docker/issues/864) +RUN apt-get update && apt-get -y install \ + python3 \ + python3-pip \ + libsm6 \ + libxext6 \ + libxrender-dev \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# install python dependencies +RUN pip3 install --upgrade pip +RUN pip3 install torch~=1.8 torchvision opencv-python-headless~=3.4 timm + +# copy inference code +WORKDIR /opt/MiDaS +COPY ./midas ./midas +COPY ./*.py ./ + +# download model weights so the docker image can be used offline +RUN cd weights && {curl -OL https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt; cd -; } +RUN python3 run.py --model_type dpt_hybrid; exit 0 + +# entrypoint (dont forget to mount input and output directories) +CMD python3 run.py --model_type dpt_hybrid diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9568ea71c755b6938ee5482ba9f09be722e75943 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md @@ -0,0 +1,259 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + + +and our [preprint](https://arxiv.org/abs/2103.13413): + +> Vision Transformers for Dense Prediction +> René Ranftl, Alexey Bochkovskiy, Vladlen Koltun + + +MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with +multi-objective optimization. +The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2). +The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters. 
+
+![](figures/Improvement_vs_FPS.png)
+
+### Setup
+
+1) Pick one or more models and download the corresponding weights to the `weights` folder:
+
+MiDaS 3.1
+- For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)
+- For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)
+- For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)
+- For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin)
+
+MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt)
+
+MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt)
+
+2) Set up dependencies:
+
+   ```shell
+   conda env create -f environment.yaml
+   conda activate midas-py310
+   ```
+
+#### optional
+
+For the Next-ViT model, execute
+
+```shell
+git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit
+```
+
+For the OpenVINO model, install
+
+```shell
+pip install openvino
+```
+
+### Usage
+
+1) Place one or more input images in the folder `input`.
+
+2) Run the model with
+
+   ```shell
+   python run.py --model_type <model_type> --input_path input --output_path output
+   ```
+   where ```<model_type>``` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type),
+   [dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type),
+   [dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type),
+   [dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type),
+   [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type).
+
+3) The resulting depth maps are written to the `output` folder.
+
+#### optional
+
+1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This
+   size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single
+   inference height but a range of different heights. Feel free to explore different heights by appending the extra
+   command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may
+   decrease the model accuracy.
+2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is
+   supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution,
+   disregarding the aspect ratio while preserving the height, use the command line argument `--square`.
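+For quick experiments it is not strictly necessary to go through `run.py`: the same models can also be scripted from Python via torch.hub (see the PyTorch Hub section below). The following is a minimal, illustrative sketch; the image path is a placeholder, and the weights are fetched from the `intel-isl/MiDaS` hub entry rather than from the local `weights` folder:
+
+```python
+import cv2
+import torch
+
+# Pick any model name listed in hubconf.py; "MiDaS_small" is a faster alternative.
+model_type = "DPT_Large"
+
+# Load the model and the matching input transform from the hub entry
+# (downloads code and weights on first use).
+midas = torch.hub.load("intel-isl/MiDaS", model_type)
+midas.eval()
+midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+transform = midas_transforms.dpt_transform if "DPT" in model_type else midas_transforms.small_transform
+
+# Read an image (OpenCV loads BGR, the transform expects RGB).
+img = cv2.cvtColor(cv2.imread("input/example.jpg"), cv2.COLOR_BGR2RGB)
+
+with torch.no_grad():
+    prediction = midas(transform(img))
+    # Upsample the predicted inverse depth back to the original image resolution.
+    prediction = torch.nn.functional.interpolate(
+        prediction.unsqueeze(1),
+        size=img.shape[:2],
+        mode="bicubic",
+        align_corners=False,
+    ).squeeze()
+
+depth = prediction.cpu().numpy()
+print(depth.shape, depth.min(), depth.max())
+```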
+
+#### via Camera
+
+   If you want the input images to be grabbed from the camera and shown in a window, leave the input and output paths
+   away and choose a model type as shown above:
+
+   ```shell
+   python run.py --model_type <model_type> --side
+   ```
+
+   The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown
+   side-by-side for comparison.
+
+#### via Docker
+
+1) Make sure you have installed Docker and the
+   [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)).
+
+2) Build the Docker image:
+
+   ```shell
+   docker build -t midas .
+   ```
+
+3) Run inference:
+
+   ```shell
+   docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas
+   ```
+
+   This command passes through all of your NVIDIA GPUs to the container, mounts the
+   `input` and `output` directories and then runs the inference.
+
+#### via PyTorch Hub
+
+The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/)
+
+#### via TensorFlow or ONNX
+
+See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory.
+
+Currently only supports MiDaS v2.1.
+
+#### via Mobile (iOS / Android)
+
+See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory.
+
+#### via ROS1 (Robot Operating System)
+
+See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory.
+
+Currently only supports MiDaS v2.1. DPT-based models to be added.
+
+### Accuracy
+
+We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets
+(see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**.
+$\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to
+MiDaS 3.0 DPT<sub>L-384</sub>. The models are grouped by the height used for inference, whereas the square training resolution is given by
+the numbers in the model names. The table also shows the **number of parameters** (in millions) and the
+**frames per second** for inference at the training resolution (for GPU RTX 3090):
+
+| MiDaS Model | DIW<br>WHDR | Eth3d<br>AbsRel | Sintel<br>AbsRel | TUM<br>δ1 | KITTI<br>δ1 | NYUv2<br>δ1 | $\color{green}{\textsf{Imp.}}$<br>% | Par.<br>M | FPS<br>
  | +|-----------------------------------------------------------------------------------------------------------------------|-------------------------:|-----------------------------:|------------------------------:|-------------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|--------------------------:| +| **Inference height 512** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** | +| | | | | | | | | | | +| **Inference height 384** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 | +| [v3.1 Swin2L-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 | +| [v3.1 Swin2B-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 | +| [v3.1 SwinL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 | +| [v3.1 BEiTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 | +| [v3.1 Next-ViTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 | +| [v3.1 BEiTB-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 | +| [v3.0 DPTL-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** | +| [v3.0 DPTH-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 | +| [v2.1 Large384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 | +| | | | | | | | | | | +| **Inference height 256** | | | | | | | | | | +| [v3.1 Swin2T-256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 | +| [v2.1 Small256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** | +| | | | | | | | | | | +| **Inference height 224** | | | | | | | | | | +| [v3.1 
LeViT224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** | + +* No zero-shot error, because models are also trained on KITTI and NYU Depth V2\ +$\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model +does not support non-square resolutions (Swin, Swin2, LeViT) or for comparison with these models. All other +validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the +improvement, because these quantities are averages over the pixels of an image and do not take into account the +advantage of more details due to a higher resolution.\ +Best values per column and same validation height in bold + +#### Improvement + +The improvement in the above table is defined as the relative zero-shot error with respect to MiDaS v3.0 +DPTL-384 and averaging over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then +the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%. + +Note that the improvements of 10% for MiDaS v2.0 → v2.1 and 21% for MiDaS v2.1 → v3.0 are not visible from the +improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large384 +and v2.0 Large384 respectively instead of v3.0 DPTL-384. + +### Depth map comparison + +Zoom in for better visibility +![](figures/Comparison.png) + +### Speed on Camera Feed + +Test configuration +- Windows 10 +- 11th Gen Intel Core i7-1185G7 3.00GHz +- 16GB RAM +- Camera resolution 640x480 +- openvino_midas_v21_small_256 + +Speed: 22 FPS + +### Changelog + +* [Dec 2022] Released MiDaS v3.1: + - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf)) + - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split + - Best model, BEiTLarge 512, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0 + - Integrated live depth estimation from camera feed +* [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large). +* [Apr 2021] Released MiDaS v3.0: + - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1 + - Additional models can be found [here](https://github.com/isl-org/DPT) +* [Nov 2020] Released MiDaS v2.1: + - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2) + - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms. 
+ - Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android) + - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots +* [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/). +* [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust +* [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1)) + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@ARTICLE {Ranftl2022, + author = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun", + title = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer", + journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence", + year = "2022", + volume = "44", + number = "3" +} +``` + +If you use a DPT-based model, please also cite: + +``` +@article{Ranftl2021, + author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun}, + title = {Vision Transformers for Dense Prediction}, + journal = {ICCV}, + year = {2021}, +} +``` + +### Acknowledgements + +Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT). +We'd like to thank the authors for making these libraries available. + +### License + +MIT License diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9abe5693b9e0de56b7d20728f4d0e6333c5822d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml @@ -0,0 +1,16 @@ +name: midas-py310 +channels: + - pytorch + - defaults +dependencies: + - nvidia::cudatoolkit=11.7 + - python=3.10.8 + - pytorch::pytorch=1.13.0 + - torchvision=0.14.0 + - pip=22.3.1 + - numpy=1.23.4 + - pip: + - opencv-python==4.6.0.66 + - imutils==0.5.4 + - timm==0.6.12 + - einops==0.6.0 \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..0d638be5151c4e305daff0c47d1ea3fc8066377d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py @@ -0,0 +1,435 @@ +dependencies = ["torch"] + +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small + +def DPT_BEiT_L_512(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_512 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_512", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS 
DPT_BEiT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitb16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2l24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2b24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_T_256(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_T_256 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2t16_256", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Swin_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Swin_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swinl12_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, 
check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Next_ViT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="next_vit_large_6m", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_LeViT_224(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_LeViT_224 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Large(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Large model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Hybrid(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Hybrid model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitb_rn50_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet() + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS_small(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True}) + + if pretrained: + checkpoint = ( + 
"https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + + +def transforms(): + import cv2 + from torchvision.transforms import Compose + from midas.transforms import Resize, NormalizeImage, PrepareForNet + from midas import transforms + + transforms.default_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.small_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.dpt_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.beit512_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 512, + 512, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin384_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin256_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.levit_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 224, + 224, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + return transforms diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/input/.placeholder 
b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/input/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..e616dfd4026f448f9e22d35c6ad8b0028732acb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py @@ -0,0 +1,196 @@ +import timm +import torch +import types + +import numpy as np +import torch.nn.functional as F + +from .utils import forward_adapted_unflatten, make_backbone_default +from timm.models.beit import gen_relative_position_index +from torch.utils.checkpoint import checkpoint +from typing import Optional + + +def forward_beit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_features") + + +def patch_embed_forward(self, x): + """ + Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes. + """ + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +def _get_rel_pos_bias(self, window_size): + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. + """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear") + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]]) + + key = str(window_size[1]) + "," + str(window_size[0]) + if key not in self.relative_position_indices.keys(): + self.relative_position_indices[key] = gen_relative_position_index(window_size) + + relative_position_bias = new_relative_position_bias_table[ + self.relative_position_indices[key].view(-1)].view( + window_size[0] * window_size[1] + 1, + window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + +def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes. 
+ """ + B, N, C = x.shape + + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + window_size = tuple(np.array(resolution) // 16) + attn = attn + self._get_rel_pos_bias(window_size) + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes. + """ + if self.gamma_1 is None: + x = x + self.drop_path1(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), resolution, + shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +def beit_forward_features(self, x): + """ + Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes. + """ + resolution = x.shape[2:] + + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias) + x = self.norm(x) + return x + + +def _make_beit_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[0, 4, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed) + backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model) + + for block in backbone.model.blocks: + attn = block.attn + attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn) + attn.forward = types.MethodType(attention_forward, attn) + attn.relative_position_indices = {} + + block.forward = types.MethodType(block_forward, block) + + return backbone + + +def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_512", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + + features = [256, 512, 1024, 1024] + + return _make_beit_backbone( + model, + features=features, + size=[512, 512], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return 
_make_beit_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..6d023a98702a0451806d26f33f8bccf931814f10 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py @@ -0,0 +1,106 @@ +import timm +import torch +import torch.nn as nn +import numpy as np + +from .utils import activations, get_activation, Transpose + + +def forward_levit(pretrained, x): + pretrained.model.forward_features(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + + layer_1 = pretrained.act_postprocess1(layer_1) + layer_2 = pretrained.act_postprocess2(layer_2) + layer_3 = pretrained.act_postprocess3(layer_3) + + return layer_1, layer_2, layer_3 + + +def _make_levit_backbone( + model, + hooks=[3, 11, 21], + patch_grid=[14, 14] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + + pretrained.activations = activations + + patch_grid_size = np.array(patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) + ) + + return pretrained + + +class ConvTransposeNorm(nn.Sequential): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm + such that ConvTranspose2d is used instead of Conv2d. 
+ """ + + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', + nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.ConvTranspose2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b4_transpose(in_chs, out_chs, activation): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 + such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. + """ + return nn.Sequential( + ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), + activation(), + ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), + activation()) + + +def _make_pretrained_levit_384(pretrained, hooks=None): + model = timm.create_model("levit_384", pretrained=pretrained) + + hooks = [3, 11, 21] if hooks == None else hooks + return _make_levit_backbone( + model, + hooks=hooks + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8afdd8b743b5ab023a359dc3b721e601b1a40d11 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py @@ -0,0 +1,39 @@ +import timm + +import torch.nn as nn + +from pathlib import Path +from .utils import activations, forward_default, get_activation + +from ..external.next_vit.classification.nextvit import * + + +def forward_next_vit(pretrained, x): + return forward_default(pretrained, x, "forward") + + +def _make_next_vit_backbone( + model, + hooks=[2, 6, 36, 39], +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + return pretrained + + +def _make_pretrained_next_vit_large_6m(hooks=None): + model = timm.create_model("nextvit_large") + + hooks = [2, 6, 36, 39] if hooks == None else hooks + return _make_next_vit_backbone( + model, + hooks=hooks, + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c71367e3e78b087f80b2ab3e2f495a9c372f1a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py @@ -0,0 +1,13 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swinl12_384(pretrained, hooks=None): + model = timm.create_model("swin_large_patch4_window12_384", 
pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4c8f1d6fc1807a207dc6b9a261c6f7b14a87a3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py @@ -0,0 +1,34 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swin2l24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2b24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2t16_256(pretrained, hooks=None): + model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) + + hooks = [1, 1, 5, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks, + patch_grid=[64, 64] + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py new file mode 100644 index 0000000000000000000000000000000000000000..94d63d408f18511179d90b3ac6f697385d1e556d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py @@ -0,0 +1,52 @@ +import torch + +import torch.nn as nn +import numpy as np + +from .utils import activations, forward_default, get_activation, Transpose + + +def forward_swin(pretrained, x): + return forward_default(pretrained, x) + + +def _make_swin_backbone( + model, + hooks=[1, 1, 17, 1], + patch_grid=[96, 96] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if hasattr(model, "patch_grid"): + used_patch_grid = model.patch_grid + else: + used_patch_grid = patch_grid + + patch_grid_size = np.array(used_patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) + ) + pretrained.act_postprocess4 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) + ) + + return pretrained diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py new file 
mode 100644 index 0000000000000000000000000000000000000000..0558899dddcfccec5f01a764d4f21738eb612149 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py @@ -0,0 +1,249 @@ +import torch + +import torch.nn as nn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def forward_default(pretrained, x, function_name="forward_features"): + exec(f"pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + if hasattr(pretrained, "act_postprocess1"): + layer_1 = pretrained.act_postprocess1(layer_1) + if hasattr(pretrained, "act_postprocess2"): + layer_2 = pretrained.act_postprocess2(layer_2) + if hasattr(pretrained, "act_postprocess3"): + layer_3 = pretrained.act_postprocess3(layer_3) + if hasattr(pretrained, "act_postprocess4"): + layer_4 = pretrained.act_postprocess4(layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def forward_adapted_unflatten(pretrained, x, function_name="forward_features"): + b, c, h, w = x.shape + + exec(f"glob = pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4) + + 
return layer_1, layer_2, layer_3, layer_4 + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def make_backbone_default( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + return pretrained diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..413f9693bd4548342280e329c9128c1a52cea920 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + +from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper, + 
make_backbone_default, Transpose) + + +def forward_vit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_flex") + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + if self.no_embed_class: + x = x + pos_embed + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if not self.no_embed_class: + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
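+    # types.MethodType binds forward_flex and _resize_pos_embed to this specific model instance,
+    # so the timm VisionTransformer class itself stays unmodified.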
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + patch_size=[16, 16], + number_stages=2, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + used_number_stages = 0 if use_vit_only else number_stages + for s in range(used_number_stages): + pretrained.model.patch_embed.backbone.stages[s].register_forward_hook( + get_activation(str(s + 1)) + ) + for s in range(used_number_stages, 4): + pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1))) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + for s in range(used_number_stages): + value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + exec(f"pretrained.act_postprocess{s + 1}=value") + for s in range(used_number_stages, 4): + if s < number_stages: + final_layer = nn.ConvTranspose2d( + in_channels=features[s], + out_channels=features[s], + kernel_size=4 // (2 ** s), + stride=4 // (2 ** s), + padding=0, + bias=True, + dilation=1, + groups=1, + ) + elif s > number_stages: + final_layer = nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ) + else: + final_layer = None + + layers = [ + readout_oper[s], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[s], + kernel_size=1, + stride=1, + padding=0, + ), + ] + if final_layer is not None: + layers.append(final_layer) + + value = nn.Sequential(*layers) + exec(f"pretrained.act_postprocess{s + 1}=value") + + pretrained.model.start_index = start_index + pretrained.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
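+    # _resize_pos_embed bilinearly interpolates the learned position embedding grid to the patch grid
+    # of the current input, which lets forward_flex handle resolutions other than the pretraining size.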
+ pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..6d87a00680bb6ed9a6d7c3043ea30a1e90361794 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py @@ -0,0 +1,439 @@ +import torch +import torch.nn as nn + +from .backbones.beit import ( + _make_pretrained_beitl16_512, + _make_pretrained_beitl16_384, + _make_pretrained_beitb16_384, + forward_beit, +) +from .backbones.swin_common import ( + forward_swin, +) +from .backbones.swin2 import ( + _make_pretrained_swin2l24_384, + _make_pretrained_swin2b24_384, + _make_pretrained_swin2t16_256, +) +from .backbones.swin import ( + _make_pretrained_swinl12_384, +) +from .backbones.levit import ( + _make_pretrained_levit_384, + forward_levit, +) +from .backbones.vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, + use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]): + if backbone == "beitl16_512": + pretrained = _make_pretrained_beitl16_512( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_512-L (backbone) + elif backbone == "beitl16_384": + pretrained = _make_pretrained_beitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_384-L (backbone) + elif backbone == "beitb16_384": + pretrained = _make_pretrained_beitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # BEiT_384-B (backbone) + elif backbone == "swin2l24_384": + pretrained = _make_pretrained_swin2l24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin2-L/12to24 (backbone) + elif backbone == "swin2b24_384": 
+ pretrained = _make_pretrained_swin2b24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [128, 256, 512, 1024], features, groups=groups, expand=expand + ) # Swin2-B/12to24 (backbone) + elif backbone == "swin2t16_256": + pretrained = _make_pretrained_swin2t16_256( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # Swin2-T/16 (backbone) + elif backbone == "swinl12_384": + pretrained = _make_pretrained_swinl12_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin-L/12 (backbone) + elif backbone == "next_vit_large_6m": + from .backbones.next_vit import _make_pretrained_next_vit_large_6m + pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks) + scratch = _make_scratch( + in_features, features, groups=groups, expand=expand + ) # Next-ViT-L on ImageNet-1K-6M (backbone) + elif backbone == "levit_384": + pretrained = _make_pretrained_levit_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [384, 512, 768], features, groups=groups, expand=expand + ) # LeViT 384 (backbone) + elif backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, 
exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. 
+ + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. 
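+
+        Accepts one or two feature maps; when two are given, the second is passed through the first
+        residual unit and added to the first before refinement, interpolation and the output convolution.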
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..3129d09cb43a7c79b23916236991fabbedb78f55 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_beit, + forward_swin, + forward_levit, + forward_vit, +) +from .backbones.levit import stem_b4_transpose +from timm.models.layers import get_act_layer + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + **kwargs + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the + # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. 
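+        # Each entry maps a backbone name to the block indices whose activations are tapped for the
+        # reassembly stages (three stages for LeViT, four for all other backbones).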
+ hooks = { + "beitl16_512": [5, 11, 17, 23], + "beitl16_384": [5, 11, 17, 23], + "beitb16_384": [2, 5, 8, 11], + "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] + "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] + "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + }[backbone] + + if "next_vit" in backbone: + in_features = { + "next_vit_large_6m": [96, 256, 512, 1024], + }[backbone] + else: + in_features = None + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks, + use_readout=readout, + in_features=in_features, + ) + + self.number_layers = len(hooks) if hooks is not None else 4 + size_refinenet3 = None + self.scratch.stem_transpose = None + + if "beit" in backbone: + self.forward_transformer = forward_beit + elif "swin" in backbone: + self.forward_transformer = forward_swin + elif "next_vit" in backbone: + from .backbones.next_vit import forward_next_vit + self.forward_transformer = forward_next_vit + elif "levit" in backbone: + self.forward_transformer = forward_levit + size_refinenet3 = 7 + self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) + else: + self.forward_transformer = forward_vit + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) + if self.number_layers >= 4: + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layers = self.forward_transformer(self.pretrained, x) + if self.number_layers == 3: + layer_1, layer_2, layer_3 = layers + else: + layer_1, layer_2, layer_3, layer_4 = layers + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + if self.number_layers >= 4: + layer_4_rn = self.scratch.layer4_rn(layer_4) + + if self.number_layers == 3: + path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:]) + else: + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + if self.scratch.stem_transpose is not None: + path_1 = self.scratch.stem_transpose(path_1) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features + head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32 + kwargs.pop("head_features_1", None) + kwargs.pop("head_features_2", None) + + 
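+        # Depth head: a 3x3 convolution halving the channels, 2x bilinear upsampling, a second 3x3
+        # convolution down to head_features_2 followed by ReLU, a 1x1 convolution to a single channel,
+        # and an optional final ReLU that keeps the predicted depth non-negative.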
head = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f1cd1f2d43054bfd3d650587c7b2ed35f1347c9e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py @@ -0,0 +1,242 @@ +import cv2 +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +from torchvision.transforms import Compose + +default_models = { + "dpt_beit_large_512": "weights/dpt_beit_large_512.pt", + "dpt_beit_large_384": "weights/dpt_beit_large_384.pt", + "dpt_beit_base_384": "weights/dpt_beit_base_384.pt", + "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt", + "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt", + "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt", + "dpt_swin_large_384": "weights/dpt_swin_large_384.pt", + "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt", + "dpt_levit_224": "weights/dpt_levit_224.pt", + "dpt_large_384": "weights/dpt_large_384.pt", + "dpt_hybrid_384": "weights/dpt_hybrid_384.pt", + "midas_v21_384": "weights/midas_v21_384.pt", + "midas_v21_small_256": "weights/midas_v21_small_256.pt", + "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml", +} + + +def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False): + """Load the specified network. 
+ + Args: + device (device): the torch device used + model_path (str): path to saved model + model_type (str): the type of the model to be loaded + optimize (bool): optimize the model to half-integer on CUDA? + height (int): inference encoder image height + square (bool): resize to a square resolution? + + Returns: + The loaded network, the transform which prepares images as input to the network and the dimensions of the + network input + """ + if "openvino" in model_type: + from openvino.runtime import Core + + keep_aspect_ratio = not square + + if model_type == "dpt_beit_large_512": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_512", + non_negative=True, + ) + net_w, net_h = 512, 512 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_base_384": + model = DPTDepthModel( + path=model_path, + backbone="beitb16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2l24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_base_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2b24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_tiny_256": + model = DPTDepthModel( + path=model_path, + backbone="swin2t16_256", + non_negative=True, + ) + net_w, net_h = 256, 256 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swinl12_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_next_vit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="next_vit_large_6m", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers + # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of + # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py + # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e) + elif model_type == "dpt_levit_224": + model = DPTDepthModel( + path=model_path, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + net_w, net_h = 224, 224 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == 
"dpt_large_384": + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid_384": + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21_384": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small_256": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "openvino_midas_v21_small_256": + ie = Core() + uncompiled_model = ie.read_model(model=model_path) + model = ie.compile_model(uncompiled_model, "CPU") + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + if not "openvino" in model_type: + print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6)) + else: + print("Model loaded, optimized with OpenVINO") + + if "openvino" in model_type: + keep_aspect_ratio = False + + if height is not None: + net_w, net_h = height, height + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + if not "openvino" in model_type: + model.eval() + + if optimize and (device == torch.device("cuda")): + if not "openvino" in model_type: + model = model.to(memory_format=torch.channels_last) + model = model.half() + else: + print("Error: OpenVINO models are already optimized. No optimization to half-float possible.") + exit() + + if not "openvino" in model_type: + model.to(device) + + return model, transform, net_w, net_h diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. 
+ + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
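+            image_interpolation_method (optional):
+                OpenCV interpolation flag used for resizing the image.
+                Defaults to cv2.INTER_AREA.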
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. 
+ """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md new file mode 100644 index 0000000000000000000000000000000000000000..45c18f7f0bfe40c0db373e8a94716867705f5827 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md @@ -0,0 +1,70 @@ +## Mobile version of MiDaS for iOS / Android - Monocular Depth Estimation + +### Accuracy + +* Old small model - ResNet50 default-decoder 384x384 +* New small model - EfficientNet-Lite3 small-decoder 256x256 + +**Zero-shot error** (the lower - the better): + +| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 | +|---|---|---|---|---|---|---| +| Old small model 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | +| New small model 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** | +| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** | + +None of Train/Valid/Test subsets of datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were not involved in Training or Fine Tuning. + +### Inference speed (FPS) on iOS / Android + +**Frames Per Second** (the higher - the better): + +| Model | iPhone CPU | iPhone GPU | iPhone NPU | OnePlus8 CPU | OnePlus8 GPU | OnePlus8 NNAPI | +|---|---|---|---|---|---|---| +| Old small model 384x384 | 0.6 | N/A | N/A | 0.45 | 0.50 | 0.50 | +| New small model 256x256 | 8 | 22 | **30** | 6 | **22** | 4 | +| SpeedUp, X times | **12.8x** | - | - | **13.2x** | **44x** | **8x** | + +N/A - run-time error (no data available) + + +#### Models: + +* Old small model - ResNet50 default-decoder 1x384x384x3, batch=1 FP32 (converters: Pytorch -> ONNX - [onnx_tf](https://github.com/onnx/onnx-tensorflow) -> (saved model) PB -> TFlite) + + (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor) + +* New small model - EfficientNet-Lite3 small-decoder 1x256x256x3, batch=1 FP32 (custom converter: Pytorch -> TFlite) + + (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor, HRWSI, IRS, TartanAir, BlendedMVS, ApolloScape) + +#### Frameworks for training and conversions: +``` +pip install torch==1.6.0 torchvision==0.7.0 +pip install tf-nightly-gpu==2.5.0.dev20201031 tensorflow-addons==0.11.2 numpy==1.18.0 +git clone --depth 1 --branch v1.6.0 https://github.com/onnx/onnx-tensorflow +``` + +#### SoC - OS - Library: + +* iPhone 11 (A13 Bionic) - iOS 13.7 - TensorFlowLiteSwift 0.0.1-nightly +* OnePlus 8 (Snapdragon 865) - Andoird 10 - org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly + + +### Citation + +This repository contains code to compute depth from a single image. 
It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2020, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2fbe357549c64ae2966d5c3013a9179427b7b396 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore @@ -0,0 +1,13 @@ +*.iml +.gradle +/local.properties +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +.DS_Store +/build +/captures +.externalNativeBuild + +/.gradle/ +/.idea/ \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md new file mode 100644 index 0000000000000000000000000000000000000000..72014bdfa2cd701a6453debbc8e53fcc15c0a5dc --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md @@ -0,0 +1,414 @@ +# TensorFlow Lite Android image classification example + +This document walks through the code of a simple Android mobile application that +demonstrates +[image classification](https://www.tensorflow.org/lite/models/image_classification/overview) +using the device camera. + +## Explore the code + +We're now going to walk through the most important parts of the sample code. + +### Get camera input + +This mobile application gets the camera input using the functions defined in the +file +[`CameraActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java). +This file depends on +[`AndroidManifest.xml`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/AndroidManifest.xml) +to set the camera orientation. + +`CameraActivity` also contains code to capture user preferences from the UI and +make them available to other classes via convenience methods. 
+ +```java +model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); +device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); +numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); +``` + +### Classifier + +This Image Classification Android reference app demonstrates two implementation +solutions, +[`lib_task_api`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api) +that leverages the out-of-box API from the +[TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier), +and +[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support) +that creates the custom inference pipleline using the +[TensorFlow Lite Support Library](https://www.tensorflow.org/lite/inference_with_metadata/lite_support). + +Both solutions implement the file `Classifier.java` (see +[the one in lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java) +and +[the one in lib_support](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java)) +that contains most of the complex logic for processing the camera input and +running inference. + +Two subclasses of the `Classifier` exist, as in `ClassifierFloatMobileNet.java` +and `ClassifierQuantizedMobileNet.java`, which contain settings for both +floating point and +[quantized](https://www.tensorflow.org/lite/performance/post_training_quantization) +models. + +The `Classifier` class implements a static method, `create`, which is used to +instantiate the appropriate subclass based on the supplied model type (quantized +vs floating point). + +#### Using the TensorFlow Lite Task Library + +Inference can be done using just a few lines of code with the +[`ImageClassifier`](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier) +in the TensorFlow Lite Task Library. + +##### Load model and create ImageClassifier + +`ImageClassifier` expects a model populated with the +[model metadata](https://www.tensorflow.org/lite/convert/metadata) and the label +file. See the +[model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier#model_compatibility_requirements) +for more details. + +`ImageClassifierOptions` allows manipulation on various inference options, such +as setting the maximum number of top scored results to return using +`setMaxResults(MAX_RESULTS)`, and setting the score threshold using +`setScoreThreshold(scoreThreshold)`. + +```java +// Create the ImageClassifier instance. +ImageClassifierOptions options = + ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build(); +imageClassifier = ImageClassifier.createFromFileAndOptions(activity, + getModelPath(), options); +``` + +`ImageClassifier` currently does not support configuring delegates and +multithread, but those are on our roadmap. Please stay tuned! + +##### Run inference + +`ImageClassifier` contains builtin logic to preprocess the input image, such as +rotating and resizing an image. Processing options can be configured through +`ImageProcessingOptions`. 
In the following example, input images are rotated to +the up-right angle and cropped to the center as the model expects a square input +(`224x224`). See the +[Java doc of `ImageClassifier`](https://github.com/tensorflow/tflite-support/blob/195b574f0aa9856c618b3f1ad87bd185cddeb657/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java#L22) +for more details about how the underlying image processing is performed. + +```java +TensorImage inputImage = TensorImage.fromBitmap(bitmap); +int width = bitmap.getWidth(); +int height = bitmap.getHeight(); +int cropSize = min(width, height); +ImageProcessingOptions imageOptions = + ImageProcessingOptions.builder() + .setOrientation(getOrientation(sensorOrientation)) + // Set the ROI to the center of the image. + .setRoi( + new Rect( + /*left=*/ (width - cropSize) / 2, + /*top=*/ (height - cropSize) / 2, + /*right=*/ (width + cropSize) / 2, + /*bottom=*/ (height + cropSize) / 2)) + .build(); + +List results = imageClassifier.classify(inputImage, + imageOptions); +``` + +The output of `ImageClassifier` is a list of `Classifications` instance, where +each `Classifications` element is a single head classification result. All the +demo models are single head models, therefore, `results` only contains one +`Classifications` object. Use `Classifications.getCategories()` to get a list of +top-k categories as specified with `MAX_RESULTS`. Each `Category` object +contains the srting label and the score of that category. + +To match the implementation of +[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support), +`results` is converted into `List` in the method, +`getRecognitions`. + +#### Using the TensorFlow Lite Support Library + +##### Load model and create interpreter + +To perform inference, we need to load a model file and instantiate an +`Interpreter`. This happens in the constructor of the `Classifier` class, along +with loading the list of class labels. Information about the device type and +number of threads is used to configure the `Interpreter` via the +`Interpreter.Options` instance passed into its constructor. Note that if a GPU, +DSP (Digital Signal Processor) or NPU (Neural Processing Unit) is available, a +[`Delegate`](https://www.tensorflow.org/lite/performance/delegates) can be used +to take full advantage of these hardware. + +Please note that there are performance edge cases and developers are adviced to +test with a representative set of devices prior to production. + +```java +protected Classifier(Activity activity, Device device, int numThreads) throws + IOException { + tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + switch (device) { + case NNAPI: + nnApiDelegate = new NnApiDelegate(); + tfliteOptions.addDelegate(nnApiDelegate); + break; + case GPU: + gpuDelegate = new GpuDelegate(); + tfliteOptions.addDelegate(gpuDelegate); + break; + case CPU: + break; + } + tfliteOptions.setNumThreads(numThreads); + tflite = new Interpreter(tfliteModel, tfliteOptions); + labels = FileUtil.loadLabels(activity, getLabelPath()); +... +``` + +For Android devices, we recommend pre-loading and memory mapping the model file +to offer faster load times and reduce the dirty pages in memory. The method +`FileUtil.loadMappedFile` does this, returning a `MappedByteBuffer` containing +the model. + +The `MappedByteBuffer` is passed into the `Interpreter` constructor, along with +an `Interpreter.Options` object. 
This object can be used to configure the +interpreter, for example by setting the number of threads (`.setNumThreads(1)`) +or enabling [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks) +(`.addDelegate(nnApiDelegate)`). + +##### Pre-process bitmap image + +Next in the `Classifier` constructor, we take the input camera bitmap image, +convert it to a `TensorImage` format for efficient processing and pre-process +it. The steps are shown in the private 'loadImage' method: + +```java +/** Loads input image, and applys preprocessing. */ +private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) { + // Loads bitmap into a TensorImage. + image.load(bitmap); + + // Creates processor for the TensorImage. + int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight()); + int numRoration = sensorOrientation / 90; + ImageProcessor imageProcessor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(cropSize, cropSize)) + .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.BILINEAR)) + .add(new Rot90Op(numRoration)) + .add(getPreprocessNormalizeOp()) + .build(); + return imageProcessor.process(inputImageBuffer); +} +``` + +The pre-processing is largely the same for quantized and float models with one +exception: Normalization. + +In `ClassifierFloatMobileNet`, the normalization parameters are defined as: + +```java +private static final float IMAGE_MEAN = 127.5f; +private static final float IMAGE_STD = 127.5f; +``` + +In `ClassifierQuantizedMobileNet`, normalization is not required. Thus the +nomalization parameters are defined as: + +```java +private static final float IMAGE_MEAN = 0.0f; +private static final float IMAGE_STD = 1.0f; +``` + +##### Allocate output object + +Initiate the output `TensorBuffer` for the output of the model. + +```java +/** Output probability TensorBuffer. */ +private final TensorBuffer outputProbabilityBuffer; + +//... +// Get the array size for the output buffer from the TensorFlow Lite model file +int probabilityTensorIndex = 0; +int[] probabilityShape = + tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, 1001} +DataType probabilityDataType = + tflite.getOutputTensor(probabilityTensorIndex).dataType(); + +// Creates the output tensor and its processor. +outputProbabilityBuffer = + TensorBuffer.createFixedSize(probabilityShape, probabilityDataType); + +// Creates the post processor for the output probability. +probabilityProcessor = + new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build(); +``` + +For quantized models, we need to de-quantize the prediction with the NormalizeOp +(as they are all essentially linear transformation). For float model, +de-quantize is not required. But to uniform the API, de-quantize is added to +float model too. Mean and std are set to 0.0f and 1.0f, respectively. To be more +specific, + +In `ClassifierQuantizedMobileNet`, the normalized parameters are defined as: + +```java +private static final float PROBABILITY_MEAN = 0.0f; +private static final float PROBABILITY_STD = 255.0f; +``` + +In `ClassifierFloatMobileNet`, the normalized parameters are defined as: + +```java +private static final float PROBABILITY_MEAN = 0.0f; +private static final float PROBABILITY_STD = 1.0f; +``` + +##### Run inference + +Inference is performed using the following in `Classifier` class: + +```java +tflite.run(inputImageBuffer.getBuffer(), + outputProbabilityBuffer.getBuffer().rewind()); +``` + +##### Recognize image + +Rather than call `run` directly, the method `recognizeImage` is used. 
It accepts +a bitmap and sensor orientation, runs inference, and returns a sorted `List` of +`Recognition` instances, each corresponding to a label. The method will return a +number of results bounded by `MAX_RESULTS`, which is 3 by default. + +`Recognition` is a simple class that contains information about a specific +recognition result, including its `title` and `confidence`. Using the +post-processing normalization method specified, the confidence is converted to +between 0 and 1 of a given class being represented by the image. + +```java +/** Gets the label to probability map. */ +Map labeledProbability = + new TensorLabel(labels, + probabilityProcessor.process(outputProbabilityBuffer)) + .getMapWithFloatValue(); +``` + +A `PriorityQueue` is used for sorting. + +```java +/** Gets the top-k results. */ +private static List getTopKProbability( + Map labelProb) { + // Find the best classifications. + PriorityQueue pq = + new PriorityQueue<>( + MAX_RESULTS, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of + // the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + for (Map.Entry entry : labelProb.entrySet()) { + pq.add(new Recognition("" + entry.getKey(), entry.getKey(), + entry.getValue(), null)); + } + + final ArrayList recognitions = new ArrayList<>(); + int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + return recognitions; +} +``` + +### Display results + +The classifier is invoked and inference results are displayed by the +`processImage()` function in +[`ClassifierActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java). + +`ClassifierActivity` is a subclass of `CameraActivity` that contains method +implementations that render the camera image, run classification, and display +the results. The method `processImage()` runs classification on a background +thread as fast as possible, rendering information on the UI thread to avoid +blocking inference and creating latency. + +```java +@Override +protected void processImage() { + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, + previewHeight); + final int imageSizeX = classifier.getImageSizeX(); + final int imageSizeY = classifier.getImageSizeY(); + + runInBackground( + new Runnable() { + @Override + public void run() { + if (classifier != null) { + final long startTime = SystemClock.uptimeMillis(); + final List results = + classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + LOGGER.v("Detect: %s", results); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + showResultsInBottomSheet(results); + showFrameInfo(previewWidth + "x" + previewHeight); + showCropInfo(imageSizeX + "x" + imageSizeY); + showCameraResolution(imageSizeX + "x" + imageSizeY); + showRotationInfo(String.valueOf(sensorOrientation)); + showInference(lastProcessingTimeMs + "ms"); + } + }); + } + readyForNextImage(); + } + }); +} +``` + +Another important role of `ClassifierActivity` is to determine user preferences +(by interrogating `CameraActivity`), and instantiate the appropriately +configured `Classifier` subclass. 
This happens when the video feed begins (via +`onPreviewSizeChosen()`) and when options are changed in the UI (via +`onInferenceConfigurationChanged()`). + +```java +private void recreateClassifier(Model model, Device device, int numThreads) { + if (classifier != null) { + LOGGER.d("Closing classifier."); + classifier.close(); + classifier = null; + } + if (device == Device.GPU && model == Model.QUANTIZED) { + LOGGER.d("Not creating classifier: GPU doesn't support quantized models."); + runOnUiThread( + () -> { + Toast.makeText(this, "GPU does not yet supported quantized models.", + Toast.LENGTH_LONG) + .show(); + }); + return; + } + try { + LOGGER.d( + "Creating classifier (model=%s, device=%s, numThreads=%d)", model, + device, numThreads); + classifier = Classifier.create(this, model, device, numThreads); + } catch (IOException e) { + LOGGER.e(e, "Failed to create classifier."); + } +} +``` diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md new file mode 100644 index 0000000000000000000000000000000000000000..faf415eb27ccc1a62357718d1e0a9b8c746de4e8 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md @@ -0,0 +1,21 @@ +# MiDaS on Android smartphone by using TensorFlow-lite (TFLite) + + +* Either use Android Studio for compilation. 
+ +* Or use ready to install apk-file: + * Or use URL: https://i.diawi.com/CVb8a9 + * Or use QR-code: + +Scan QR-code or open URL -> Press `Install application` -> Press `Download` and wait for download -> Open -> Install -> Open -> Press: Allow MiDaS to take photo and video from the camera While using the APP + +![CVb8a9](https://user-images.githubusercontent.com/4096485/97727213-38552500-1ae1-11eb-8b76-4ea11216f76d.png) + +---- + +To use another model, you should convert it to `model_opt.tflite` and place it to the directory: `models\src\main\assets` + + +---- + +Original repository: https://github.com/isl-org/MiDaS diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1ae74c6780c277d75fedfb7511ff51f69941b48b --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore @@ -0,0 +1,3 @@ +/build + +/build/ \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..94e9886a55c7d54f71b424bb246c849dd6bd795d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle @@ -0,0 +1,56 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 28 + defaultConfig { + applicationId "org.tensorflow.lite.examples.classification" + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + lintOptions { + abortOnError false + } + flavorDimensions "tfliteInference" + productFlavors { + // The TFLite inference is built using the TFLite Support library. + support { + dimension "tfliteInference" + } + // The TFLite inference is built using the TFLite Task library. 
+ taskApi { + dimension "tfliteInference" + } + } + +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + supportImplementation project(":lib_support") + taskApiImplementation project(":lib_task_api") + implementation 'androidx.appcompat:appcompat:1.0.0' + implementation 'androidx.coordinatorlayout:coordinatorlayout:1.0.0' + implementation 'com.google.android.material:material:1.0.0' + + androidTestImplementation 'androidx.test.ext:junit:1.1.1' + androidTestImplementation 'com.google.truth:truth:1.0.1' + androidTestImplementation 'androidx.test:runner:1.2.0' + androidTestImplementation 'androidx.test:rules:1.1.0' +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt new file mode 100644 index 0000000000000000000000000000000000000000..bdfad31f9b3e694817025d8b8f2ca0b40aa436bb --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt @@ -0,0 +1,3 @@ +red_fox 0.79403335 +kit_fox 0.16753247 +grey_fox 0.03619214 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt new file mode 100644 index 0000000000000000000000000000000000000000..3668ce54df0d1e57e31c58281d6085b83928f991 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt @@ -0,0 +1,3 @@ +red_fox 0.85 +kit_fox 0.13 +grey_fox 0.02 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..3653d8799092492ebbb16c7c956eb50e3d404aa4 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java new file mode 100644 index 0000000000000000000000000000000000000000..0194132890aae659c2a70d33106306ed665b22e8 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tensorflow.lite.examples.classification; + +import static com.google.common.truth.Truth.assertThat; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.util.Log; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.rule.ActivityTestRule; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Scanner; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.examples.classification.tflite.Classifier; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +/** Golden test for Image Classification Reference app. */ +@RunWith(AndroidJUnit4.class) +public class ClassifierTest { + + @Rule + public ActivityTestRule rule = + new ActivityTestRule<>(ClassifierActivity.class); + + private static final String[] INPUTS = {"fox.jpg"}; + private static final String[] GOLDEN_OUTPUTS_SUPPORT = {"fox-mobilenet_v1_1.0_224_support.txt"}; + private static final String[] GOLDEN_OUTPUTS_TASK = {"fox-mobilenet_v1_1.0_224_task_api.txt"}; + + @Test + public void classificationResultsShouldNotChange() throws IOException { + ClassifierActivity activity = rule.getActivity(); + Classifier classifier = Classifier.create(activity, Model.FLOAT_MOBILENET, Device.CPU, 1); + for (int i = 0; i < INPUTS.length; i++) { + String imageFileName = INPUTS[i]; + String goldenOutputFileName; + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // This is a temporary workaround to set different golden rest results as the preprocessing + // of lib_support and lib_task_api are different. Will merge them once the above TODO is + // resolved. 
+ if (Classifier.TAG.equals("ClassifierWithSupport")) { + goldenOutputFileName = GOLDEN_OUTPUTS_SUPPORT[i]; + } else { + goldenOutputFileName = GOLDEN_OUTPUTS_TASK[i]; + } + Bitmap input = loadImage(imageFileName); + List goldenOutput = loadRecognitions(goldenOutputFileName); + + List result = classifier.recognizeImage(input, 0); + Iterator goldenOutputIterator = goldenOutput.iterator(); + + for (Recognition actual : result) { + Assert.assertTrue(goldenOutputIterator.hasNext()); + Recognition expected = goldenOutputIterator.next(); + assertThat(actual.getTitle()).isEqualTo(expected.getTitle()); + assertThat(actual.getConfidence()).isWithin(0.01f).of(expected.getConfidence()); + } + } + } + + private static Bitmap loadImage(String fileName) { + AssetManager assetManager = + InstrumentationRegistry.getInstrumentation().getContext().getAssets(); + InputStream inputStream = null; + try { + inputStream = assetManager.open(fileName); + } catch (IOException e) { + Log.e("Test", "Cannot load image from assets"); + } + return BitmapFactory.decodeStream(inputStream); + } + + private static List loadRecognitions(String fileName) { + AssetManager assetManager = + InstrumentationRegistry.getInstrumentation().getContext().getAssets(); + InputStream inputStream = null; + try { + inputStream = assetManager.open(fileName); + } catch (IOException e) { + Log.e("Test", "Cannot load probability results from assets"); + } + Scanner scanner = new Scanner(inputStream); + List result = new ArrayList<>(); + while (scanner.hasNext()) { + String category = scanner.next(); + category = category.replace('_', ' '); + if (!scanner.hasNextFloat()) { + break; + } + float probability = scanner.nextFloat(); + Recognition recognition = new Recognition(null, category, probability, null); + result.add(recognition); + } + return result; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..7a414d5176a117262dce56c2220e6b71791287de --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..d1eb26c862c04bf573ecc4eb127e7460f0b100fc --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java @@ -0,0 +1,717 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import android.Manifest; +import android.app.Fragment; +import android.content.Context; +import android.content.pm.PackageManager; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.RectF; +import android.hardware.Camera; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.Image; +import android.media.Image.Plane; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.Build; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Trace; +import androidx.annotation.NonNull; +import androidx.annotation.UiThread; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewTreeObserver; +import android.view.WindowManager; +import android.widget.AdapterView; +import android.widget.ImageView; +import android.widget.LinearLayout; +import android.widget.Spinner; +import android.widget.TextView; +import android.widget.Toast; +import com.google.android.material.bottomsheet.BottomSheetBehavior; +import java.nio.ByteBuffer; +import java.util.List; +import org.tensorflow.lite.examples.classification.env.ImageUtils; +import org.tensorflow.lite.examples.classification.env.Logger; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public abstract class CameraActivity extends AppCompatActivity + implements OnImageAvailableListener, + Camera.PreviewCallback, + View.OnClickListener, + AdapterView.OnItemSelectedListener { + private static final Logger LOGGER = new Logger(); + + private static final int PERMISSIONS_REQUEST = 1; + + private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA; + protected int previewWidth = 0; + protected int previewHeight = 0; + private Handler handler; + private HandlerThread handlerThread; + private boolean useCamera2API; + private boolean isProcessingFrame = false; + private byte[][] yuvBytes = new byte[3][]; + private int[] rgbBytes = null; + private int yRowStride; + private Runnable postInferenceCallback; + private Runnable imageConverter; + private LinearLayout bottomSheetLayout; + private LinearLayout gestureLayout; + private BottomSheetBehavior sheetBehavior; + protected TextView recognitionTextView, + recognition1TextView, + recognition2TextView, + recognitionValueTextView, + recognition1ValueTextView, + recognition2ValueTextView; + protected TextView frameValueTextView, + cropValueTextView, + cameraResolutionTextView, + rotationTextView, + inferenceTimeTextView; + protected ImageView bottomSheetArrowImageView; + private ImageView plusImageView, minusImageView; + private Spinner modelSpinner; + private Spinner deviceSpinner; + private TextView threadsTextView; + + //private Model model = Model.QUANTIZED_EFFICIENTNET; + //private Device device = Device.CPU; + private 
Model model = Model.FLOAT_EFFICIENTNET; + private Device device = Device.GPU; + private int numThreads = -1; + + @Override + protected void onCreate(final Bundle savedInstanceState) { + LOGGER.d("onCreate " + this); + super.onCreate(null); + getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON); + + setContentView(R.layout.tfe_ic_activity_camera); + + if (hasPermission()) { + setFragment(); + } else { + requestPermission(); + } + + threadsTextView = findViewById(R.id.threads); + plusImageView = findViewById(R.id.plus); + minusImageView = findViewById(R.id.minus); + modelSpinner = findViewById(R.id.model_spinner); + deviceSpinner = findViewById(R.id.device_spinner); + bottomSheetLayout = findViewById(R.id.bottom_sheet_layout); + gestureLayout = findViewById(R.id.gesture_layout); + sheetBehavior = BottomSheetBehavior.from(bottomSheetLayout); + bottomSheetArrowImageView = findViewById(R.id.bottom_sheet_arrow); + + ViewTreeObserver vto = gestureLayout.getViewTreeObserver(); + vto.addOnGlobalLayoutListener( + new ViewTreeObserver.OnGlobalLayoutListener() { + @Override + public void onGlobalLayout() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.JELLY_BEAN) { + gestureLayout.getViewTreeObserver().removeGlobalOnLayoutListener(this); + } else { + gestureLayout.getViewTreeObserver().removeOnGlobalLayoutListener(this); + } + // int width = bottomSheetLayout.getMeasuredWidth(); + int height = gestureLayout.getMeasuredHeight(); + + sheetBehavior.setPeekHeight(height); + } + }); + sheetBehavior.setHideable(false); + + sheetBehavior.setBottomSheetCallback( + new BottomSheetBehavior.BottomSheetCallback() { + @Override + public void onStateChanged(@NonNull View bottomSheet, int newState) { + switch (newState) { + case BottomSheetBehavior.STATE_HIDDEN: + break; + case BottomSheetBehavior.STATE_EXPANDED: + { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_down); + } + break; + case BottomSheetBehavior.STATE_COLLAPSED: + { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + } + break; + case BottomSheetBehavior.STATE_DRAGGING: + break; + case BottomSheetBehavior.STATE_SETTLING: + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + break; + } + } + + @Override + public void onSlide(@NonNull View bottomSheet, float slideOffset) {} + }); + + recognitionTextView = findViewById(R.id.detected_item); + recognitionValueTextView = findViewById(R.id.detected_item_value); + recognition1TextView = findViewById(R.id.detected_item1); + recognition1ValueTextView = findViewById(R.id.detected_item1_value); + recognition2TextView = findViewById(R.id.detected_item2); + recognition2ValueTextView = findViewById(R.id.detected_item2_value); + + frameValueTextView = findViewById(R.id.frame_info); + cropValueTextView = findViewById(R.id.crop_info); + cameraResolutionTextView = findViewById(R.id.view_info); + rotationTextView = findViewById(R.id.rotation_info); + inferenceTimeTextView = findViewById(R.id.inference_info); + + modelSpinner.setOnItemSelectedListener(this); + deviceSpinner.setOnItemSelectedListener(this); + + plusImageView.setOnClickListener(this); + minusImageView.setOnClickListener(this); + + model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); + device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); + numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); + } + + protected int[] getRgbBytes() { + imageConverter.run(); + return rgbBytes; + } + + protected int 
getLuminanceStride() { + return yRowStride; + } + + protected byte[] getLuminance() { + return yuvBytes[0]; + } + + /** Callback for android.hardware.Camera API */ + @Override + public void onPreviewFrame(final byte[] bytes, final Camera camera) { + if (isProcessingFrame) { + LOGGER.w("Dropping frame!"); + return; + } + + try { + // Initialize the storage bitmaps once when the resolution is known. + if (rgbBytes == null) { + Camera.Size previewSize = camera.getParameters().getPreviewSize(); + previewHeight = previewSize.height; + previewWidth = previewSize.width; + rgbBytes = new int[previewWidth * previewHeight]; + onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90); + } + } catch (final Exception e) { + LOGGER.e(e, "Exception!"); + return; + } + + isProcessingFrame = true; + yuvBytes[0] = bytes; + yRowStride = previewWidth; + + imageConverter = + new Runnable() { + @Override + public void run() { + ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes); + } + }; + + postInferenceCallback = + new Runnable() { + @Override + public void run() { + camera.addCallbackBuffer(bytes); + isProcessingFrame = false; + } + }; + processImage(); + } + + /** Callback for Camera2 API */ + @Override + public void onImageAvailable(final ImageReader reader) { + // We need wait until we have some size from onPreviewSizeChosen + if (previewWidth == 0 || previewHeight == 0) { + return; + } + if (rgbBytes == null) { + rgbBytes = new int[previewWidth * previewHeight]; + } + try { + final Image image = reader.acquireLatestImage(); + + if (image == null) { + return; + } + + if (isProcessingFrame) { + image.close(); + return; + } + isProcessingFrame = true; + Trace.beginSection("imageAvailable"); + final Plane[] planes = image.getPlanes(); + fillBytes(planes, yuvBytes); + yRowStride = planes[0].getRowStride(); + final int uvRowStride = planes[1].getRowStride(); + final int uvPixelStride = planes[1].getPixelStride(); + + imageConverter = + new Runnable() { + @Override + public void run() { + ImageUtils.convertYUV420ToARGB8888( + yuvBytes[0], + yuvBytes[1], + yuvBytes[2], + previewWidth, + previewHeight, + yRowStride, + uvRowStride, + uvPixelStride, + rgbBytes); + } + }; + + postInferenceCallback = + new Runnable() { + @Override + public void run() { + image.close(); + isProcessingFrame = false; + } + }; + + processImage(); + } catch (final Exception e) { + LOGGER.e(e, "Exception!"); + Trace.endSection(); + return; + } + Trace.endSection(); + } + + @Override + public synchronized void onStart() { + LOGGER.d("onStart " + this); + super.onStart(); + } + + @Override + public synchronized void onResume() { + LOGGER.d("onResume " + this); + super.onResume(); + + handlerThread = new HandlerThread("inference"); + handlerThread.start(); + handler = new Handler(handlerThread.getLooper()); + } + + @Override + public synchronized void onPause() { + LOGGER.d("onPause " + this); + + handlerThread.quitSafely(); + try { + handlerThread.join(); + handlerThread = null; + handler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + + super.onPause(); + } + + @Override + public synchronized void onStop() { + LOGGER.d("onStop " + this); + super.onStop(); + } + + @Override + public synchronized void onDestroy() { + LOGGER.d("onDestroy " + this); + super.onDestroy(); + } + + protected synchronized void runInBackground(final Runnable r) { + if (handler != null) { + handler.post(r); + } + } + + @Override + public void onRequestPermissionsResult( + final int 
requestCode, final String[] permissions, final int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (requestCode == PERMISSIONS_REQUEST) { + if (allPermissionsGranted(grantResults)) { + setFragment(); + } else { + requestPermission(); + } + } + } + + private static boolean allPermissionsGranted(final int[] grantResults) { + for (int result : grantResults) { + if (result != PackageManager.PERMISSION_GRANTED) { + return false; + } + } + return true; + } + + private boolean hasPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED; + } else { + return true; + } + } + + private void requestPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA)) { + Toast.makeText( + CameraActivity.this, + "Camera permission is required for this demo", + Toast.LENGTH_LONG) + .show(); + } + requestPermissions(new String[] {PERMISSION_CAMERA}, PERMISSIONS_REQUEST); + } + } + + // Returns true if the device supports the required hardware level, or better. + private boolean isHardwareLevelSupported( + CameraCharacteristics characteristics, int requiredLevel) { + int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL); + if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) { + return requiredLevel == deviceLevel; + } + // deviceLevel is not LEGACY, can use numerical sort + return requiredLevel <= deviceLevel; + } + + private String chooseCamera() { + final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE); + try { + for (final String cameraId : manager.getCameraIdList()) { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + if (map == null) { + continue; + } + + // Fallback to camera1 API for internal cameras that don't have full support. + // This should help with legacy situations where using the camera2 API causes + // distorted or otherwise broken previews. 
+ useCamera2API = + (facing == CameraCharacteristics.LENS_FACING_EXTERNAL) + || isHardwareLevelSupported( + characteristics, CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL); + LOGGER.i("Camera API lv2?: %s", useCamera2API); + return cameraId; + } + } catch (CameraAccessException e) { + LOGGER.e(e, "Not allowed to access camera"); + } + + return null; + } + + protected void setFragment() { + String cameraId = chooseCamera(); + + Fragment fragment; + if (useCamera2API) { + CameraConnectionFragment camera2Fragment = + CameraConnectionFragment.newInstance( + new CameraConnectionFragment.ConnectionCallback() { + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + previewHeight = size.getHeight(); + previewWidth = size.getWidth(); + CameraActivity.this.onPreviewSizeChosen(size, rotation); + } + }, + this, + getLayoutId(), + getDesiredPreviewFrameSize()); + + camera2Fragment.setCamera(cameraId); + fragment = camera2Fragment; + } else { + fragment = + new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize()); + } + + getFragmentManager().beginTransaction().replace(R.id.container, fragment).commit(); + } + + protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) { + // Because of the variable row stride it's not possible to know in + // advance the actual necessary dimensions of the yuv planes. + for (int i = 0; i < planes.length; ++i) { + final ByteBuffer buffer = planes[i].getBuffer(); + if (yuvBytes[i] == null) { + LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity()); + yuvBytes[i] = new byte[buffer.capacity()]; + } + buffer.get(yuvBytes[i]); + } + } + + protected void readyForNextImage() { + if (postInferenceCallback != null) { + postInferenceCallback.run(); + } + } + + protected int getScreenOrientation() { + switch (getWindowManager().getDefaultDisplay().getRotation()) { + case Surface.ROTATION_270: + return 270; + case Surface.ROTATION_180: + return 180; + case Surface.ROTATION_90: + return 90; + default: + return 0; + } + } + + @UiThread + protected void showResultsInTexture(float[] img_array, int imageSizeX, int imageSizeY) { + float maxval = Float.NEGATIVE_INFINITY; + float minval = Float.POSITIVE_INFINITY; + for (float cur : img_array) { + maxval = Math.max(maxval, cur); + minval = Math.min(minval, cur); + } + float multiplier = 0; + if ((maxval - minval) > 0) multiplier = 255 / (maxval - minval); + + int[] img_normalized = new int[img_array.length]; + for (int i = 0; i < img_array.length; ++i) { + float val = (float) (multiplier * (img_array[i] - minval)); + img_normalized[i] = (int) val; + } + + + + TextureView textureView = findViewById(R.id.textureView3); + //AutoFitTextureView textureView = (AutoFitTextureView) findViewById(R.id.texture); + + if(textureView.isAvailable()) { + int width = imageSizeX; + int height = imageSizeY; + + Canvas canvas = textureView.lockCanvas(); + canvas.drawColor(Color.BLUE); + Paint paint = new Paint(); + paint.setStyle(Paint.Style.FILL); + paint.setARGB(255, 150, 150, 150); + + int canvas_size = Math.min(canvas.getWidth(), canvas.getHeight()); + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); + + for (int ii = 0; ii < width; ii++) //pass the screen pixels in 2 directions + { + for (int jj = 0; jj < height; jj++) { + //int val = img_normalized[ii + jj * width]; + int index = (width - ii - 1) + (height - jj - 1) * width; + if(index < img_array.length) { + int val = img_normalized[index]; + bitmap.setPixel(ii, jj, 
Color.rgb(val, val, val)); + } + } + } + + canvas.drawBitmap(bitmap, null, new RectF(0, 0, canvas_size, canvas_size), null); + + textureView.unlockCanvasAndPost(canvas); + + } + + } + + protected void showResultsInBottomSheet(List results) { + if (results != null && results.size() >= 3) { + Recognition recognition = results.get(0); + if (recognition != null) { + if (recognition.getTitle() != null) recognitionTextView.setText(recognition.getTitle()); + if (recognition.getConfidence() != null) + recognitionValueTextView.setText( + String.format("%.2f", (100 * recognition.getConfidence())) + "%"); + } + + Recognition recognition1 = results.get(1); + if (recognition1 != null) { + if (recognition1.getTitle() != null) recognition1TextView.setText(recognition1.getTitle()); + if (recognition1.getConfidence() != null) + recognition1ValueTextView.setText( + String.format("%.2f", (100 * recognition1.getConfidence())) + "%"); + } + + Recognition recognition2 = results.get(2); + if (recognition2 != null) { + if (recognition2.getTitle() != null) recognition2TextView.setText(recognition2.getTitle()); + if (recognition2.getConfidence() != null) + recognition2ValueTextView.setText( + String.format("%.2f", (100 * recognition2.getConfidence())) + "%"); + } + } + } + + protected void showFrameInfo(String frameInfo) { + frameValueTextView.setText(frameInfo); + } + + protected void showCropInfo(String cropInfo) { + cropValueTextView.setText(cropInfo); + } + + protected void showCameraResolution(String cameraInfo) { + cameraResolutionTextView.setText(cameraInfo); + } + + protected void showRotationInfo(String rotation) { + rotationTextView.setText(rotation); + } + + protected void showInference(String inferenceTime) { + inferenceTimeTextView.setText(inferenceTime); + } + + protected Model getModel() { + return model; + } + + private void setModel(Model model) { + if (this.model != model) { + LOGGER.d("Updating model: " + model); + this.model = model; + onInferenceConfigurationChanged(); + } + } + + protected Device getDevice() { + return device; + } + + private void setDevice(Device device) { + if (this.device != device) { + LOGGER.d("Updating device: " + device); + this.device = device; + final boolean threadsEnabled = device == Device.CPU; + plusImageView.setEnabled(threadsEnabled); + minusImageView.setEnabled(threadsEnabled); + threadsTextView.setText(threadsEnabled ? 
String.valueOf(numThreads) : "N/A"); + onInferenceConfigurationChanged(); + } + } + + protected int getNumThreads() { + return numThreads; + } + + private void setNumThreads(int numThreads) { + if (this.numThreads != numThreads) { + LOGGER.d("Updating numThreads: " + numThreads); + this.numThreads = numThreads; + onInferenceConfigurationChanged(); + } + } + + protected abstract void processImage(); + + protected abstract void onPreviewSizeChosen(final Size size, final int rotation); + + protected abstract int getLayoutId(); + + protected abstract Size getDesiredPreviewFrameSize(); + + protected abstract void onInferenceConfigurationChanged(); + + @Override + public void onClick(View v) { + if (v.getId() == R.id.plus) { + String threads = threadsTextView.getText().toString().trim(); + int numThreads = Integer.parseInt(threads); + if (numThreads >= 9) return; + setNumThreads(++numThreads); + threadsTextView.setText(String.valueOf(numThreads)); + } else if (v.getId() == R.id.minus) { + String threads = threadsTextView.getText().toString().trim(); + int numThreads = Integer.parseInt(threads); + if (numThreads == 1) { + return; + } + setNumThreads(--numThreads); + threadsTextView.setText(String.valueOf(numThreads)); + } + } + + @Override + public void onItemSelected(AdapterView parent, View view, int pos, long id) { + if (parent == modelSpinner) { + setModel(Model.valueOf(parent.getItemAtPosition(pos).toString().toUpperCase())); + } else if (parent == deviceSpinner) { + setDevice(Device.valueOf(parent.getItemAtPosition(pos).toString())); + } + } + + @Override + public void onNothingSelected(AdapterView parent) { + // Do nothing. + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..13e5c0dc341a86b1ddd66c4b562e0bf767641b42 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java @@ -0,0 +1,575 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tensorflow.lite.examples.classification; + +import android.annotation.SuppressLint; +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.res.Configuration; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.text.TextUtils; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.Toast; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.Logger; + +/** + * Camera Connection Fragment that captures images from camera. + * + *

<p>Instantiated by newInstance.

+ */ +@SuppressWarnings("FragmentNotInstantiable") +public class CameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + + /** + * The camera preview size will be chosen to be the smallest frame by pixel size capable of + * containing a DESIRED_SIZE x DESIRED_SIZE square. + */ + private static final int MINIMUM_PREVIEW_SIZE = 320; + + /** Conversion from screen rotation to JPEG orientation. */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + + private static final String FRAGMENT_DIALOG = "dialog"; + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */ + private final Semaphore cameraOpenCloseLock = new Semaphore(1); + /** A {@link OnImageAvailableListener} to receive frames as they are available. */ + private final OnImageAvailableListener imageListener; + /** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */ + private final Size inputSize; + /** The layout identifier to inflate for this Fragment. */ + private final int layout; + + private final ConnectionCallback cameraConnectionCallback; + private final CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + @Override + public void onCaptureProgressed( + final CameraCaptureSession session, + final CaptureRequest request, + final CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + final CameraCaptureSession session, + final CaptureRequest request, + final TotalCaptureResult result) {} + }; + /** ID of the current {@link CameraDevice}. */ + private String cameraId; + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + /** A {@link CameraCaptureSession } for camera preview. */ + private CameraCaptureSession captureSession; + /** A reference to the opened {@link CameraDevice}. */ + private CameraDevice cameraDevice; + /** The rotation in degrees of the camera sensor from the display. */ + private Integer sensorOrientation; + /** The {@link Size} of camera preview. */ + private Size previewSize; + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + /** A {@link Handler} for running tasks in the background. */ + private Handler backgroundHandler; + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + /** An {@link ImageReader} that handles preview frame capture. 
*/ + private ImageReader previewReader; + /** {@link CaptureRequest.Builder} for the camera preview */ + private CaptureRequest.Builder previewRequestBuilder; + /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */ + private CaptureRequest previewRequest; + /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + @Override + public void onOpened(final CameraDevice cd) { + // This method is called when the camera is opened. We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = cd; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(final CameraDevice cd) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + } + + @Override + public void onError(final CameraDevice cd, final int error) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + final Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + @SuppressLint("ValidFragment") + private CameraConnectionFragment( + final ConnectionCallback connectionCallback, + final OnImageAvailableListener imageListener, + final int layout, + final Size inputSize) { + this.cameraConnectionCallback = connectionCallback; + this.imageListener = imageListener; + this.layout = layout; + this.inputSize = inputSize; + } + + /** + * Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose + * width and height are at least as large as the minimum of both, or an exact match if possible. + * + * @param choices The list of sizes that the camera supports for the intended output class + * @param width The minimum desired width + * @param height The minimum desired height + * @return The optimal {@code Size}, or an arbitrary one if none were big enough + */ + protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) { + final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE); + final Size desiredSize = new Size(width, height); + + // Collect the supported resolutions that are at least as big as the preview Surface + boolean exactSizeFound = false; + final List bigEnough = new ArrayList(); + final List tooSmall = new ArrayList(); + for (final Size option : choices) { + if (option.equals(desiredSize)) { + // Set the size but don't return yet so that remaining sizes will still be logged. 
+ exactSizeFound = true; + } + + if (option.getHeight() >= minSize && option.getWidth() >= minSize) { + bigEnough.add(option); + } else { + tooSmall.add(option); + } + } + + LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize); + LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]"); + LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]"); + + if (exactSizeFound) { + LOGGER.i("Exact size match found."); + return desiredSize; + } + + // Pick the smallest of those, assuming we found any + if (bigEnough.size() > 0) { + final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea()); + LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight()); + return chosenSize; + } else { + LOGGER.e("Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static CameraConnectionFragment newInstance( + final ConnectionCallback callback, + final OnImageAvailableListener imageListener, + final int layout, + final Size inputSize) { + return new CameraConnectionFragment(callback, imageListener, layout, inputSize); + } + + /** + * Shows a {@link Toast} on the UI thread. + * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + Toast.makeText(activity, text, Toast.LENGTH_SHORT).show(); + } + }); + } + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(layout, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + public void setCamera(String cameraId) { + this.cameraId = cameraId; + } + + /** Sets up member variables related to camera. */ + private void setUpCameraOutputs() { + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); + + // Danger, W.R.! Attempting to use too large a preview size could exceed the camera + // bus' bandwidth limitation, resulting in gorgeous previews but the storage of + // garbage capture data. 
+ previewSize = + chooseOptimalSize( + map.getOutputSizes(SurfaceTexture.class), + inputSize.getWidth(), + inputSize.getHeight()); + + // We fit the aspect ratio of TextureView to the size of preview we picked. + final int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.tfe_ic_camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + throw new IllegalStateException(getString(R.string.tfe_ic_camera_error)); + } + + cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation); + } + + /** Opens the camera specified by {@link CameraConnectionFragment#cameraId}. */ + private void openCamera(final int width, final int height) { + setUpCameraOutputs(); + configureTransform(width, height); + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + /** Closes the current {@link CameraDevice}. */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != previewReader) { + previewReader.close(); + previewReader = null; + } + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("ImageListener"); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + /** Creates a new {@link CameraCaptureSession} for camera preview. */ + private void createCameraPreviewSession() { + try { + final SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + final Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. 
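A host activity is not included in this excerpt, so the following is only a hypothetical sketch of how newInstance and setCamera are meant to be wired together; R.id.container, the lambda used as the ConnectionCallback, and passing `this` as the OnImageAvailableListener are assumptions:

    // Hypothetical host wiring (the real CameraActivity is not part of this excerpt).
    CameraConnectionFragment fragment =
        CameraConnectionFragment.newInstance(
            (size, rotation) -> onPreviewSizeChosen(size, rotation),  // ConnectionCallback
            this,                                                     // OnImageAvailableListener
            R.layout.tfe_ic_camera_connection_fragment,
            new Size(640, 480));
    fragment.setCamera(cameraId);  // e.g. an id returned by CameraManager.getCameraIdList()
    getFragmentManager().beginTransaction().replace(R.id.container, fragment).commit();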
+ previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight()); + + // Create the reader for the preview frames. + previewReader = + ImageReader.newInstance( + previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2); + + previewReader.setOnImageAvailableListener(imageListener, backgroundHandler); + previewRequestBuilder.addTarget(previewReader.getSurface()); + + // Here, we create a CameraCaptureSession for camera preview. + cameraDevice.createCaptureSession( + Arrays.asList(surface, previewReader.getSurface()), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(final CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + // Flash is automatically enabled when necessary. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH); + + // Finally, we start displaying the camera preview. + previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + @Override + public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + /** + * Configures the necessary {@link Matrix} transformation to `mTextureView`. This method should be + * called after the camera preview size is determined in setUpCameraOutputs and also the size of + * `mTextureView` is fixed. + * + * @param viewWidth The width of `mTextureView` + * @param viewHeight The height of `mTextureView` + */ + private void configureTransform(final int viewWidth, final int viewHeight) { + final Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + final Matrix matrix = new Matrix(); + final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + final float centerX = viewRect.centerX(); + final float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + final float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + } else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** + * Callback for Activities to use to initialize their data once the selected preview size is + * known. 
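Plugging made-up numbers into configureTransform above makes the landscape branch easier to follow; none of the values below come from the patch itself:

    // TextureView measured 1920x1080 in landscape (Surface.ROTATION_90),
    // previewSize chosen as 640x480, so bufferRect is 480x640.
    float scale = Math.max(1080f / 480f,    // viewHeight / previewHeight = 2.25
                           1920f / 640f);   // viewWidth  / previewWidth  = 3.0  -> scale = 3.0
    int degrees = 90 * (Surface.ROTATION_90 - 2);  // 90 * (1 - 2) = -90
    // setRectToRect maps the view onto the 480x640 buffer rect, the result is
    // scaled by 3.0 about the view centre, then rotated by -90 degrees so the
    // preview stays upright and unstretched.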
+ */ + public interface ConnectionCallback { + void onPreviewSizeChosen(Size size, int cameraRotation); + } + + /** Compares two {@code Size}s based on their areas. */ + static class CompareSizesByArea implements Comparator { + @Override + public int compare(final Size lhs, final Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** Shows an error message dialog. */ + public static class ErrorDialog extends DialogFragment { + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(final String message) { + final ErrorDialog dialog = new ErrorDialog(); + final Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(final Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(final DialogInterface dialogInterface, final int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..24b5d72fdb42d47e5d2c87e3f944b71105748c1b --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java @@ -0,0 +1,238 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
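The CompareSizesByArea comparator above casts to long before multiplying. Real preview dimensions are far too small to overflow an int on their own, so the cast is defensive; a contrived illustration of what it protects against:

    int overflowed = 65_000 * 65_000;        // wraps: int cannot hold 4_225_000_000
    long exact = (long) 65_000 * 65_000;     // 4_225_000_000L, so the subtraction compares correctly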
+ */ + +package org.tensorflow.lite.examples.classification; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Typeface; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.SystemClock; +import android.util.Size; +import android.util.TypedValue; +import android.view.TextureView; +import android.view.ViewStub; +import android.widget.TextView; +import android.widget.Toast; +import java.io.IOException; +import java.util.List; +import java.util.ArrayList; + +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.BorderedText; +import org.tensorflow.lite.examples.classification.env.Logger; +import org.tensorflow.lite.examples.classification.tflite.Classifier; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; + +import android.widget.ImageView; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.Rect; +import android.graphics.RectF; +import android.graphics.PixelFormat; +import java.nio.ByteBuffer; + +public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener { + private static final Logger LOGGER = new Logger(); + private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); + private static final float TEXT_SIZE_DIP = 10; + private Bitmap rgbFrameBitmap = null; + private long lastProcessingTimeMs; + private Integer sensorOrientation; + private Classifier classifier; + private BorderedText borderedText; + /** Input image size of the model along x axis. */ + private int imageSizeX; + /** Input image size of the model along y axis. 
*/ + private int imageSizeY; + + @Override + protected int getLayoutId() { + return R.layout.tfe_ic_camera_connection_fragment; + } + + @Override + protected Size getDesiredPreviewFrameSize() { + return DESIRED_PREVIEW_SIZE; + } + + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + final float textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + borderedText = new BorderedText(textSizePx); + borderedText.setTypeface(Typeface.MONOSPACE); + + recreateClassifier(getModel(), getDevice(), getNumThreads()); + if (classifier == null) { + LOGGER.e("No classifier on preview!"); + return; + } + + previewWidth = size.getWidth(); + previewHeight = size.getHeight(); + + sensorOrientation = rotation - getScreenOrientation(); + LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); + + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); + } + + @Override + protected void processImage() { + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); + final int cropSize = Math.min(previewWidth, previewHeight); + + runInBackground( + new Runnable() { + @Override + public void run() { + if (classifier != null) { + final long startTime = SystemClock.uptimeMillis(); + //final List results = + // classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + final List results = new ArrayList<>(); + + float[] img_array = classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + + + /* + float maxval = Float.NEGATIVE_INFINITY; + float minval = Float.POSITIVE_INFINITY; + for (float cur : img_array) { + maxval = Math.max(maxval, cur); + minval = Math.min(minval, cur); + } + float multiplier = 0; + if ((maxval - minval) > 0) multiplier = 255 / (maxval - minval); + + int[] img_normalized = new int[img_array.length]; + for (int i = 0; i < img_array.length; ++i) { + float val = (float) (multiplier * (img_array[i] - minval)); + img_normalized[i] = (int) val; + } + + + + TextureView textureView = findViewById(R.id.textureView3); + //AutoFitTextureView textureView = (AutoFitTextureView) findViewById(R.id.texture); + + if(textureView.isAvailable()) { + int width = imageSizeX; + int height = imageSizeY; + + Canvas canvas = textureView.lockCanvas(); + canvas.drawColor(Color.BLUE); + Paint paint = new Paint(); + paint.setStyle(Paint.Style.FILL); + paint.setARGB(255, 150, 150, 150); + + int canvas_size = Math.min(canvas.getWidth(), canvas.getHeight()); + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); + + for (int ii = 0; ii < width; ii++) //pass the screen pixels in 2 directions + { + for (int jj = 0; jj < height; jj++) { + //int val = img_normalized[ii + jj * width]; + int index = (width - ii - 1) + (height - jj - 1) * width; + if(index < img_array.length) { + int val = img_normalized[index]; + bitmap.setPixel(ii, jj, Color.rgb(val, val, val)); + } + } + } + + canvas.drawBitmap(bitmap, null, new RectF(0, 0, canvas_size, canvas_size), null); + + textureView.unlockCanvasAndPost(canvas); + + } + */ + + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + LOGGER.v("Detect: %s", results); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + //showResultsInBottomSheet(results); + showResultsInTexture(img_array, imageSizeX, imageSizeY); + showFrameInfo(previewWidth + "x" + previewHeight); + 
showCropInfo(imageSizeX + "x" + imageSizeY); + showCameraResolution(cropSize + "x" + cropSize); + showRotationInfo(String.valueOf(sensorOrientation)); + showInference(lastProcessingTimeMs + "ms"); + } + }); + } + readyForNextImage(); + } + }); + } + + @Override + protected void onInferenceConfigurationChanged() { + if (rgbFrameBitmap == null) { + // Defer creation until we're getting camera frames. + return; + } + final Device device = getDevice(); + final Model model = getModel(); + final int numThreads = getNumThreads(); + runInBackground(() -> recreateClassifier(model, device, numThreads)); + } + + private void recreateClassifier(Model model, Device device, int numThreads) { + if (classifier != null) { + LOGGER.d("Closing classifier."); + classifier.close(); + classifier = null; + } + if (device == Device.GPU + && (model == Model.QUANTIZED_MOBILENET || model == Model.QUANTIZED_EFFICIENTNET)) { + LOGGER.d("Not creating classifier: GPU doesn't support quantized models."); + runOnUiThread( + () -> { + Toast.makeText(this, R.string.tfe_ic_gpu_quant_error, Toast.LENGTH_LONG).show(); + }); + return; + } + try { + LOGGER.d( + "Creating classifier (model=%s, device=%s, numThreads=%d)", model, device, numThreads); + classifier = Classifier.create(this, model, device, numThreads); + } catch (IOException | IllegalArgumentException e) { + LOGGER.e(e, "Failed to create classifier."); + runOnUiThread( + () -> { + Toast.makeText(this, e.getMessage(), Toast.LENGTH_LONG).show(); + }); + return; + } + + // Updates the input image size. + imageSizeX = classifier.getImageSizeX(); + imageSizeY = classifier.getImageSizeY(); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..760fe90375450c7b1356603c83fb37a68548ca13 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java @@ -0,0 +1,203 @@ +package org.tensorflow.lite.examples.classification; + +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
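The commented-out block in ClassifierActivity.processImage above min/max-normalises the raw depth array to 0..255 grey levels before drawing it. As a standalone sketch of that step (the helper name normalizeToGrey is hypothetical, not part of the patch):

    // Min/max-normalises a raw depth buffer to 0..255 grey levels, mirroring
    // the commented-out visualisation code above.
    static int[] normalizeToGrey(float[] depth) {
      float min = Float.POSITIVE_INFINITY, max = Float.NEGATIVE_INFINITY;
      for (float v : depth) {
        min = Math.min(min, v);
        max = Math.max(max, v);
      }
      float multiplier = (max - min) > 0 ? 255f / (max - min) : 0f;
      int[] grey = new int[depth.length];
      for (int i = 0; i < depth.length; i++) {
        grey[i] = (int) (multiplier * (depth[i] - min));
      }
      return grey;
    }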
+ */ + +import android.annotation.SuppressLint; +import android.app.Fragment; +import android.graphics.SurfaceTexture; +import android.hardware.Camera; +import android.hardware.Camera.CameraInfo; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import java.io.IOException; +import java.util.List; +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.ImageUtils; +import org.tensorflow.lite.examples.classification.env.Logger; + +public class LegacyCameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + /** Conversion from screen rotation to JPEG orientation. */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + private Camera camera; + private Camera.PreviewCallback imageListener; + private Size desiredSize; + /** The layout identifier to inflate for this Fragment. */ + private int layout; + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + + int index = getCameraId(); + camera = Camera.open(index); + + try { + Camera.Parameters parameters = camera.getParameters(); + List focusModes = parameters.getSupportedFocusModes(); + if (focusModes != null + && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) { + parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); + } + List cameraSizes = parameters.getSupportedPreviewSizes(); + Size[] sizes = new Size[cameraSizes.size()]; + int i = 0; + for (Camera.Size size : cameraSizes) { + sizes[i++] = new Size(size.width, size.height); + } + Size previewSize = + CameraConnectionFragment.chooseOptimalSize( + sizes, desiredSize.getWidth(), desiredSize.getHeight()); + parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight()); + camera.setDisplayOrientation(90); + camera.setParameters(parameters); + camera.setPreviewTexture(texture); + } catch (IOException exception) { + camera.release(); + } + + camera.setPreviewCallbackWithBuffer(imageListener); + Camera.Size s = camera.getParameters().getPreviewSize(); + camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]); + + textureView.setAspectRatio(s.height, s.width); + + camera.startPreview(); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) {} + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + /** An additional thread for running tasks that shouldn't block the UI. 
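The legacy fragment above registers imageListener with setPreviewCallbackWithBuffer and posts a single buffer sized by ImageUtils.getYUVByteSize, so the callback, which lives in the host activity and is not shown here, has to hand that buffer back after every frame. A minimal hypothetical callback (assumes android.hardware.Camera is imported):

    Camera.PreviewCallback imageListener = new Camera.PreviewCallback() {
      @Override
      public void onPreviewFrame(byte[] data, Camera camera) {
        try {
          // ... convert the NV21/YUV bytes in `data` to RGB and run inference ...
        } finally {
          camera.addCallbackBuffer(data);  // recycle the buffer, or no further frames arrive
        }
      }
    };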
*/ + private HandlerThread backgroundThread; + + @SuppressLint("ValidFragment") + public LegacyCameraConnectionFragment( + final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) { + this.imageListener = imageListener; + this.layout = layout; + this.desiredSize = desiredSize; + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(layout, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + + if (textureView.isAvailable()) { + if (camera != null) { + camera.startPreview(); + } + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + stopCamera(); + stopBackgroundThread(); + super.onPause(); + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("CameraBackground"); + backgroundThread.start(); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + protected void stopCamera() { + if (camera != null) { + camera.stopPreview(); + camera.setPreviewCallback(null); + camera.release(); + camera = null; + } + } + + private int getCameraId() { + CameraInfo ci = new CameraInfo(); + for (int i = 0; i < Camera.getNumberOfCameras(); i++) { + Camera.getCameraInfo(i, ci); + if (ci.facing == CameraInfo.CAMERA_FACING_BACK) return i; + } + return -1; // No camera found + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java new file mode 100644 index 0000000000000000000000000000000000000000..62e99ae70c2a7c4c60a776e7490742c5339e85f3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ +public class AutoFitTextureView extends TextureView { + private int ratioWidth = 0; + private int ratioHeight = 0; + + public AutoFitTextureView(final Context context) { + this(context, null); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, + * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. + * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(final int width, final int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + ratioWidth = width; + ratioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + final int width = MeasureSpec.getSize(widthMeasureSpec); + final int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == ratioWidth || 0 == ratioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * ratioWidth / ratioHeight) { + setMeasuredDimension(width, width * ratioHeight / ratioWidth); + } else { + setMeasuredDimension(height * ratioWidth / ratioHeight, height); + } + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java new file mode 100644 index 0000000000000000000000000000000000000000..dc302ac04f145c9a1673a2d7e630a94a05ab1b1a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
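A quick worked example of AutoFitTextureView.onMeasure with made-up numbers (the camera fragments above call setAspectRatio(previewHeight, previewWidth) in portrait):

    // Preview 640x480 in portrait -> setAspectRatio(480, 640); parent offers 1080x1920.
    int width = 1080, height = 1920;
    int ratioWidth = 480, ratioHeight = 640;
    // width < height * ratioWidth / ratioHeight  ->  1080 < 1920 * 480 / 640 = 1440, so:
    int measuredWidth  = width;                               // 1080
    int measuredHeight = width * ratioHeight / ratioWidth;    // 1080 * 640 / 480 = 1440
    // The view settles on a 1080x1440 (3:4) area instead of stretching the preview.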
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.graphics.Canvas; +import android.util.AttributeSet; +import android.view.View; +import java.util.LinkedList; +import java.util.List; + +/** A simple View providing a render callback to other classes. */ +public class OverlayView extends View { + private final List callbacks = new LinkedList(); + + public OverlayView(final Context context, final AttributeSet attrs) { + super(context, attrs); + } + + public void addCallback(final DrawCallback callback) { + callbacks.add(callback); + } + + @Override + public synchronized void draw(final Canvas canvas) { + for (final DrawCallback callback : callbacks) { + callback.drawCallback(canvas); + } + } + + /** Interface defining the callback for client classes. */ + public interface DrawCallback { + public void drawCallback(final Canvas canvas); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java new file mode 100644 index 0000000000000000000000000000000000000000..2c57f603f12200079c888793cfa40d9b10dabde3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
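No caller of OverlayView appears in this excerpt, so the following usage is a hypothetical sketch; the view id R.id.tracking_overlay is assumed:

    OverlayView overlay = (OverlayView) findViewById(R.id.tracking_overlay);
    overlay.addCallback(
        new OverlayView.DrawCallback() {
          @Override
          public void drawCallback(final Canvas canvas) {
            // draw boxes or debug text on top of the camera preview here
          }
        });
    overlay.postInvalidate();  // schedules OverlayView.draw(), which runs every registered callback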
+==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.graphics.Canvas; +import android.graphics.Paint; +import android.util.AttributeSet; +import android.util.TypedValue; +import android.view.View; +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public class RecognitionScoreView extends View implements ResultsView { + private static final float TEXT_SIZE_DIP = 16; + private final float textSizePx; + private final Paint fgPaint; + private final Paint bgPaint; + private List results; + + public RecognitionScoreView(final Context context, final AttributeSet set) { + super(context, set); + + textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + fgPaint = new Paint(); + fgPaint.setTextSize(textSizePx); + + bgPaint = new Paint(); + bgPaint.setColor(0xcc4285f4); + } + + @Override + public void setResults(final List results) { + this.results = results; + postInvalidate(); + } + + @Override + public void onDraw(final Canvas canvas) { + final int x = 10; + int y = (int) (fgPaint.getTextSize() * 1.5f); + + canvas.drawPaint(bgPaint); + + if (results != null) { + for (final Recognition recog : results) { + canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint); + y += (int) (fgPaint.getTextSize() * 1.5f); + } + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java new file mode 100644 index 0000000000000000000000000000000000000000..d055eb5f161a57fc439716efe6d49b7e45ef3fc7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public interface ResultsView { + public void setResults(final List results); +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 0000000000000000000000000000000000000000..b1517edf496ef5800b97d046b92012a9f94a34d0 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml new file mode 100644 index 0000000000000000000000000000000000000000..70f4b24e35039e6bfc35989bcbe570a4bdc2ae07 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml @@ -0,0 +1,9 @@ + + + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml new file mode 100644 index 0000000000000000000000000000000000000000..757f4503314fb9e5837f68ac515f4487d9b5fc2c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml @@ -0,0 +1,9 @@ + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml new file mode 100644 index 0000000000000000000000000000000000000000..a64b853e79137f0fd95f9d5fa6e0552cc255c7ae --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml @@ -0,0 +1,9 @@ + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 0000000000000000000000000000000000000000..d5fccc538c179838bfdce779c26eebb4fa0b5ce9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml new file mode 100644 index 
0000000000000000000000000000000000000000..b8f5d3559c4e83072d5d73a3241d240aa68daccf --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml @@ -0,0 +1,13 @@ + + + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..f0e1dae7afa15cf4a832de708f345482a6dfeff6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml new file mode 100644 index 0000000000000000000000000000000000000000..97e5e7c6df25da48977f9064a888fd3735e4986f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml @@ -0,0 +1,32 @@ + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml new file mode 100644 index 0000000000000000000000000000000000000000..77a348af90e2ed995ff106cd209cbf304c6b9153 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml @@ -0,0 +1,321 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c2a915e91af65a077d2e01db4ca21acd42906f3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c2a915e91af65a077d2e01db4ca21acd42906f3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml 
b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000000000000000000000000000000000..ed82bafb536474c6a88c996b439a2781f31f3d3e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml @@ -0,0 +1,8 @@ + + + #ffa800 + #ff6f00 + #425066 + + #66000000 + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml new file mode 100644 index 0000000000000000000000000000000000000000..5d3609029ca66b612c88b4f395e4e2e3cfc1f0e6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml @@ -0,0 +1,5 @@ + + + 15dp + 8dp + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..7d763d85efc49879c8d3c0641484f5f472bfaca0 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml @@ -0,0 +1,21 @@ + + Midas + This device doesn\'t support Camera2 API. + GPU does not yet supported quantized models. + Model: + + Float_EfficientNet + + + + Device: + + GPU + CPU + NNAPI + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..ad09a13ec6b2de8920a7441c9992f3cc0eedcfda --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..14492756847191ca3beff4c2e012d378c4e44be6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle @@ -0,0 +1,27 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. 
+ +buildscript { + + repositories { + google() + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:4.0.0' + classpath 'de.undercouch:gradle-download-task:4.0.2' + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..9592636c07d9d5e6f61c0cfce1311d3e1ffcf34d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties @@ -0,0 +1,15 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +android.useAndroidX=true +android.enableJetifier=true diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..f3d88b1c2faf2fc91d853cd5d4242b5547257070 Binary files /dev/null and b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..1b16c34a71cf212ed0cfb883d14d1b8511903eb2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.1.1-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew new file mode 100644 index 0000000000000000000000000000000000000000..2fe81a7d95e4f9ad2c9b2a046707d36ceb3980b3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew @@ -0,0 +1,183 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat new file mode 100644 index 0000000000000000000000000000000000000000..9618d8d9607cd91a0efb866bcac4810064ba6fac --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat @@ -0,0 +1,100 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..5d463975293264765a941795601cddb6cfc84f00 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation project(":models") + implementation 'androidx.appcompat:appcompat:1.1.0' + + // Build off of nightly TensorFlow Lite + implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true } + // Use local TensorFlow library + // implementation 'org.tensorflow:tensorflow-lite-local:0.0.0' +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java new file mode 100644 index 0000000000000000000000000000000000000000..24ec573e7d184e7d64118a723d6645fd92d6e6d9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -0,0 +1,376 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import static java.lang.Math.min; + +import android.app.Activity; +import android.graphics.Bitmap; +import android.graphics.RectF; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Log; +import android.view.TextureView; +import android.view.ViewStub; + +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.gpu.GpuDelegate; +import org.tensorflow.lite.nnapi.NnApiDelegate; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.TensorProcessor; +import org.tensorflow.lite.support.image.ImageProcessor; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.image.ops.ResizeOp; +import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod; +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp; +import org.tensorflow.lite.support.image.ops.Rot90Op; +import org.tensorflow.lite.support.label.TensorLabel; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** A classifier specialized to label images using TensorFlow Lite. 
*/ +public abstract class Classifier { + public static final String TAG = "ClassifierWithSupport"; + + /** The model type used for classification. */ + public enum Model { + FLOAT_MOBILENET, + QUANTIZED_MOBILENET, + QUANTIZED_EFFICIENTNET, + FLOAT_EFFICIENTNET + } + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** Number of results to show in the UI. */ + private static final int MAX_RESULTS = 3; + + /** The loaded TensorFlow Lite model. */ + + /** Image size along the x axis. */ + private final int imageSizeX; + + /** Image size along the y axis. */ + private final int imageSizeY; + + /** Optional GPU delegate for accleration. */ + private GpuDelegate gpuDelegate = null; + + /** Optional NNAPI delegate for accleration. */ + private NnApiDelegate nnApiDelegate = null; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + protected Interpreter tflite; + + /** Options for configuring the Interpreter. */ + private final Interpreter.Options tfliteOptions = new Interpreter.Options(); + + /** Labels corresponding to the output of the vision model. */ + private final List labels; + + /** Input image TensorBuffer. */ + private TensorImage inputImageBuffer; + + /** Output probability TensorBuffer. */ + private final TensorBuffer outputProbabilityBuffer; + + /** Processer to apply post processing of the output probability. */ + private final TensorProcessor probabilityProcessor; + + /** + * Creates a classifier with the provided configuration. + * + * @param activity The current Activity. + * @param model The model to use for classification. + * @param device The device to use for classification. + * @param numThreads The number of threads to use for classification. + * @return A classifier with the desired configuration. + */ + public static Classifier create(Activity activity, Model model, Device device, int numThreads) + throws IOException { + if (model == Model.QUANTIZED_MOBILENET) { + return new ClassifierQuantizedMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_MOBILENET) { + return new ClassifierFloatMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_EFFICIENTNET) { + return new ClassifierFloatEfficientNet(activity, device, numThreads); + } else if (model == Model.QUANTIZED_EFFICIENTNET) { + return new ClassifierQuantizedEfficientNet(activity, device, numThreads); + } else { + throw new UnsupportedOperationException(); + } + } + + /** An immutable result returned by a Classifier describing what was recognized. */ + public static class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** Display name for the recognition. */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** Optional location within the source image for the location of the recognized object. 
*/ + private RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + public void setLocation(RectF location) { + this.location = location; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + /** Initializes a {@code Classifier}. */ + protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + switch (device) { + case NNAPI: + nnApiDelegate = new NnApiDelegate(); + tfliteOptions.addDelegate(nnApiDelegate); + break; + case GPU: + gpuDelegate = new GpuDelegate(); + tfliteOptions.addDelegate(gpuDelegate); + break; + case CPU: + break; + } + tfliteOptions.setNumThreads(numThreads); + tflite = new Interpreter(tfliteModel, tfliteOptions); + + // Loads labels out from the label file. + labels = FileUtil.loadLabels(activity, getLabelPath()); + + // Reads type and shape of input and output tensors, respectively. + int imageTensorIndex = 0; + int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3} + if(imageShape[1] != imageShape[2]) { + imageSizeY = imageShape[2]; + imageSizeX = imageShape[3]; + } else { + imageSizeY = imageShape[1]; + imageSizeX = imageShape[2]; + } + DataType imageDataType = tflite.getInputTensor(imageTensorIndex).dataType(); + int probabilityTensorIndex = 0; + int[] probabilityShape = + tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, NUM_CLASSES} + DataType probabilityDataType = tflite.getOutputTensor(probabilityTensorIndex).dataType(); + + // Creates the input tensor. + inputImageBuffer = new TensorImage(imageDataType); + + // Creates the output tensor and its processor. + outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType); + + // Creates the post processor for the output probability. + probabilityProcessor = new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build(); + + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + } + + /** Runs inference and returns the classification results. */ + //public List recognizeImage(final Bitmap bitmap, int sensorOrientation) { + public float[] recognizeImage(final Bitmap bitmap, int sensorOrientation) { + // Logs this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + Trace.beginSection("loadImage"); + long startTimeForLoadImage = SystemClock.uptimeMillis(); + inputImageBuffer = loadImage(bitmap, sensorOrientation); + long endTimeForLoadImage = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage)); + + // Runs the inference call. 
+ Trace.beginSection("runInference"); + long startTimeForReference = SystemClock.uptimeMillis(); + tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind()); + long endTimeForReference = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference)); + + float[] img_array = outputProbabilityBuffer.getFloatArray(); + + // Gets the map of label and probability. + //Map labeledProbability = + // new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer)) + // .getMapWithFloatValue(); + Trace.endSection(); + + // Gets top-k results. + return img_array;//getTopKProbability(labeledProbability); + } + + /** Closes the interpreter and model to release resources. */ + public void close() { + if (tflite != null) { + tflite.close(); + tflite = null; + } + if (gpuDelegate != null) { + gpuDelegate.close(); + gpuDelegate = null; + } + if (nnApiDelegate != null) { + nnApiDelegate.close(); + nnApiDelegate = null; + } + } + + /** Get the image size along the x axis. */ + public int getImageSizeX() { + return imageSizeX; + } + + /** Get the image size along the y axis. */ + public int getImageSizeY() { + return imageSizeY; + } + + /** Loads input image, and applies preprocessing. */ + private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) { + // Loads bitmap into a TensorImage. + inputImageBuffer.load(bitmap); + + // Creates processor for the TensorImage. + int cropSize = min(bitmap.getWidth(), bitmap.getHeight()); + int numRotation = sensorOrientation / 90; + // TODO(b/143564309): Fuse ops inside ImageProcessor. + ImageProcessor imageProcessor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(cropSize, cropSize)) + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // To get the same inference results as lib_task_api, which is built on top of the Task + // Library, use ResizeMethod.BILINEAR. + .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.NEAREST_NEIGHBOR)) + //.add(new ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOR)) + .add(new Rot90Op(numRotation)) + .add(getPreprocessNormalizeOp()) + .build(); + return imageProcessor.process(inputImageBuffer); + } + + /** Gets the top-k results. */ + private static List getTopKProbability(Map labelProb) { + // Find the best classifications. + PriorityQueue pq = + new PriorityQueue<>( + MAX_RESULTS, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + for (Map.Entry entry : labelProb.entrySet()) { + pq.add(new Recognition("" + entry.getKey(), entry.getKey(), entry.getValue(), null)); + } + + final ArrayList recognitions = new ArrayList<>(); + int recognitionsSize = min(pq.size(), MAX_RESULTS); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + return recognitions; + } + + /** Gets the name of the model file stored in Assets. */ + protected abstract String getModelPath(); + + /** Gets the name of the label file stored in Assets. */ + protected abstract String getLabelPath(); + + /** Gets the TensorOperator to nomalize the input image in preprocessing. */ + protected abstract TensorOperator getPreprocessNormalizeOp(); + + /** + * Gets the TensorOperator to dequantize the output probability in post processing. + * + *
<p>For quantized model, we need to de-quantize the prediction with NormalizeOp (as they are all + * essentially linear transformations). For float model, de-quantize is not required. But to + * keep the API uniform, de-quantize is added to the float model too. Mean and std are set to 0.0f and + * 1.0f, respectively. + */ + protected abstract TensorOperator getPostprocessNormalizeOp(); +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..14dd027b26baefaedd979a8ac37f0bf984210ed4 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlowLite classifier works with the float EfficientNet model. */ +public class ClassifierFloatEfficientNet extends Classifier { + + private static final float IMAGE_MEAN = 115.0f; //127.0f; + private static final float IMAGE_STD = 58.0f; //128.0f; + + /** + * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f + * and 1.0f, respectively, to bypass the normalization. + */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 1.0f; + + /** + * Initializes a {@code ClassifierFloatEfficientNet}. + * + * @param activity + */ + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets.
+ //return "efficientnet-lite0-fp32.tflite"; + return "model_opt.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_without_background.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..40519de07cf5e887773250a4609a832b6060d684 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlowLite classifier works with the float MobileNet model. */ +public class ClassifierFloatMobileNet extends Classifier { + + /** Float MobileNet requires additional normalization of the input. */ + private static final float IMAGE_MEAN = 127.5f; + + private static final float IMAGE_STD = 127.5f; + + /** + * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f + * and 1.0f, respectively, to bypass the normalization. + */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 1.0f; + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param activity + */ + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets.
+ return "model_0.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..d0d62f58d18333b6360ec30a4c85c9f1d38955ce --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ +public class ClassifierQuantizedEfficientNet extends Classifier { + + /** + * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to + * bypass the normalization. + */ + private static final float IMAGE_MEAN = 0.0f; + + private static final float IMAGE_STD = 1.0f; + + /** The quantized EfficientNet model requires additional dequantization of the output probability. */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 255.0f; + + /** + * Initializes a {@code ClassifierQuantizedEfficientNet}. + * + * @param activity + */ + public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets.
+ return "model_quant.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_without_background.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..94b06e3df659005c287733a8a37672863fdadd71 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlow Lite classifier works with the quantized MobileNet model. */ +public class ClassifierQuantizedMobileNet extends Classifier { + + /** + * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to + * bypass the normalization. + */ + private static final float IMAGE_MEAN = 0.0f; + + private static final float IMAGE_STD = 1.0f; + + /** Quantized MobileNet requires additional dequantization to the output probability. */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 255.0f; + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param activity + */ + public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. 
+ return "model_quant_0.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b5983986e3d56a77a41676b9195b0d0882b5fb96 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation project(":models") + implementation 'androidx.appcompat:appcompat:1.1.0' + + // Build off of nightly TensorFlow Lite Task Library + implementation('org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-metadata:0.0.0-nightly') { changing = true } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java new file mode 100644 index 0000000000000000000000000000000000000000..45da52a0d0dfa203255e0f2d44901ee0618e739f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -0,0 +1,278 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import static java.lang.Math.min; + +import android.app.Activity; +import android.graphics.Bitmap; +import android.graphics.Rect; +import android.graphics.RectF; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Log; +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.support.metadata.MetadataExtractor; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions.Orientation; +import org.tensorflow.lite.task.vision.classifier.Classifications; +import org.tensorflow.lite.task.vision.classifier.ImageClassifier; +import org.tensorflow.lite.task.vision.classifier.ImageClassifier.ImageClassifierOptions; + +/** A classifier specialized to label images using TensorFlow Lite. */ +public abstract class Classifier { + public static final String TAG = "ClassifierWithTaskApi"; + + /** The model type used for classification. */ + public enum Model { + FLOAT_MOBILENET, + QUANTIZED_MOBILENET, + FLOAT_EFFICIENTNET, + QUANTIZED_EFFICIENTNET + } + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** Number of results to show in the UI. 
*/ + private static final int MAX_RESULTS = 3; + + /** Image size along the x axis. */ + private final int imageSizeX; + + /** Image size along the y axis. */ + private final int imageSizeY; + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + protected final ImageClassifier imageClassifier; + + /** + * Creates a classifier with the provided configuration. + * + * @param activity The current Activity. + * @param model The model to use for classification. + * @param device The device to use for classification. + * @param numThreads The number of threads to use for classification. + * @return A classifier with the desired configuration. + */ + public static Classifier create(Activity activity, Model model, Device device, int numThreads) + throws IOException { + if (model == Model.QUANTIZED_MOBILENET) { + return new ClassifierQuantizedMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_MOBILENET) { + return new ClassifierFloatMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_EFFICIENTNET) { + return new ClassifierFloatEfficientNet(activity, device, numThreads); + } else if (model == Model.QUANTIZED_EFFICIENTNET) { + return new ClassifierQuantizedEfficientNet(activity, device, numThreads); + } else { + throw new UnsupportedOperationException(); + } + } + + /** An immutable result returned by a Classifier describing what was recognized. */ + public static class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** Display name for the recognition. */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** Optional location within the source image for the location of the recognized object. */ + private RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + public void setLocation(RectF location) { + this.location = location; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + /** Initializes a {@code Classifier}. */ + protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + if (device != Device.CPU || numThreads != 1) { + throw new IllegalArgumentException( + "Manipulating the hardware accelerators and numbers of threads is not allowed in the Task" + + " library currently. Only CPU + single thread is allowed."); + } + + // Create the ImageClassifier instance. 
+ ImageClassifierOptions options = + ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build(); + imageClassifier = ImageClassifier.createFromFileAndOptions(activity, getModelPath(), options); + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + + // Get the input image size information of the underlying tflite model. + MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + MetadataExtractor metadataExtractor = new MetadataExtractor(tfliteModel); + // Image shape is in the format of {1, height, width, 3}. + int[] imageShape = metadataExtractor.getInputTensorShape(/*inputIndex=*/ 0); + imageSizeY = imageShape[1]; + imageSizeX = imageShape[2]; + } + + /** Runs inference and returns the classification results. */ + public List recognizeImage(final Bitmap bitmap, int sensorOrientation) { + // Logs this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + TensorImage inputImage = TensorImage.fromBitmap(bitmap); + int width = bitmap.getWidth(); + int height = bitmap.getHeight(); + int cropSize = min(width, height); + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // Task Library resize the images using bilinear interpolation, which is slightly different from + // the nearest neighbor sampling algorithm used in lib_support. See + // https://github.com/tensorflow/examples/blob/0ef3d93e2af95d325c70ef3bcbbd6844d0631e07/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java#L310. + ImageProcessingOptions imageOptions = + ImageProcessingOptions.builder() + .setOrientation(getOrientation(sensorOrientation)) + // Set the ROI to the center of the image. + .setRoi( + new Rect( + /*left=*/ (width - cropSize) / 2, + /*top=*/ (height - cropSize) / 2, + /*right=*/ (width + cropSize) / 2, + /*bottom=*/ (height + cropSize) / 2)) + .build(); + + // Runs the inference call. + Trace.beginSection("runInference"); + long startTimeForReference = SystemClock.uptimeMillis(); + List results = imageClassifier.classify(inputImage, imageOptions); + long endTimeForReference = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference)); + + Trace.endSection(); + + return getRecognitions(results); + } + + /** Closes the interpreter and model to release resources. */ + public void close() { + if (imageClassifier != null) { + imageClassifier.close(); + } + } + + /** Get the image size along the x axis. */ + public int getImageSizeX() { + return imageSizeX; + } + + /** Get the image size along the y axis. */ + public int getImageSizeY() { + return imageSizeY; + } + + /** + * Converts a list of {@link Classifications} objects into a list of {@link Recognition} objects + * to match the interface of other inference method, such as using the TFLite + * Support Library.. + */ + private static List getRecognitions(List classifications) { + + final ArrayList recognitions = new ArrayList<>(); + // All the demo models are single head models. Get the first Classifications in the results. 
+ for (Category category : classifications.get(0).getCategories()) { + recognitions.add( + new Recognition( + "" + category.getLabel(), category.getLabel(), category.getScore(), null)); + } + return recognitions; + } + + /* Convert the camera orientation in degree into {@link ImageProcessingOptions#Orientation}.*/ + private static Orientation getOrientation(int cameraOrientation) { + switch (cameraOrientation / 90) { + case 3: + return Orientation.BOTTOM_LEFT; + case 2: + return Orientation.BOTTOM_RIGHT; + case 1: + return Orientation.TOP_RIGHT; + default: + return Orientation.TOP_LEFT; + } + } + + /** Gets the name of the model file stored in Assets. */ + protected abstract String getModelPath(); +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..250794cc12d0e603aa47502322dc646d50689848 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlowLite classifier works with the float EfficientNet model. */ +public class ClassifierFloatEfficientNet extends Classifier { + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. 
+ //return "efficientnet-lite0-fp32.tflite"; + return "model.tflite"; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..0707de98de41395eaf3ddcfd74d6e36229a63760 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlowLite classifier works with the float MobileNet model. */ +public class ClassifierFloatMobileNet extends Classifier { + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224.tflite"; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..05ca4fa6c409d0274a396c9b26c3c39ca8a8194e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; + +/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ +public class ClassifierQuantizedEfficientNet extends Classifier { + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "efficientnet-lite0-int8.tflite"; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..978b08eeaf52a23eede437d61045db08d1dff163 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlow Lite classifier works with the quantized MobileNet model. */ +public class ClassifierQuantizedMobileNet extends Classifier { + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. 
+ * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224_quant.tflite"; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..8d825707af20cbbead6c4599f075599148e3511c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle @@ -0,0 +1,40 @@ +apply plugin: 'com.android.library' +apply plugin: 'de.undercouch.download' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory and comment out this line. +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +apply from:'download.gradle' diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle new file mode 100644 index 0000000000000000000000000000000000000000..ce76974a2c3bc6f8214461028e0dfa6ebc25d588 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle @@ -0,0 +1,10 @@ +def modelFloatDownloadUrl = "https://github.com/isl-org/MiDaS/releases/download/v2_1/model_opt.tflite" +def modelFloatFile = "model_opt.tflite" + +task downloadModelFloat(type: Download) { + src "${modelFloatDownloadUrl}" + dest project.ext.ASSET_DIR + "/${modelFloatFile}" + overwrite false +} + +preBuild.dependsOn downloadModelFloat diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. 
+# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..42951a56497c5f947efe4aea6a07462019fb152c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim 
spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer 
glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby 
ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt new file mode 100644 index 0000000000000000000000000000000000000000..f40829ed0fc318c673860fae4be6c48529da116e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt @@ -0,0 +1,1000 @@ +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard 
+African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant 
+grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover 
+maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser 
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b8ebe235758d3d0f3d357c51ed54d78ac7eea8e
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py
@@ -0,0 +1,75 @@
+# Flex ops are included in the nightly build of the TensorFlow Python package. You can use TFLite models containing Flex ops with the same Python API as normal TFLite models. The nightly TensorFlow build can be installed with the pip command below.
+# Flex ops will be added to the TensorFlow Python package and the tflite_runtime package from version 2.3 for Linux and 2.4 for other environments.
+# https://www.tensorflow.org/lite/guide/ops_select#running_the_model
+
+# You must use: tf-nightly
+# pip install tf-nightly
+
+import os
+import glob
+import cv2
+import numpy as np
+
+import tensorflow as tf
+
+width=256
+height=256
+model_name="model.tflite"
+#model_name="model_quant.tflite"
+image_name="dog.jpg"
+
+# input
+img = cv2.imread(image_name)
+img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+mean=[0.485, 0.456, 0.406]
+std=[0.229, 0.224, 0.225]
+img = (img - mean) / std
+
+img_resized = tf.image.resize(img, [width,height], method='bicubic', preserve_aspect_ratio=False)
+#img_resized = tf.transpose(img_resized, [2, 0, 1])
+img_input = img_resized.numpy()
+reshape_img = img_input.reshape(1,width,height,3)
+tensor = tf.convert_to_tensor(reshape_img, dtype=tf.float32)
+
+# load model
+print("Load model...")
+interpreter = tf.lite.Interpreter(model_path=model_name)
+print("Allocate tensor...")
+interpreter.allocate_tensors()
+print("Get input/output details...")
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+print("Get input shape...")
+input_shape = input_details[0]['shape']
+print(input_shape)
+print(input_details)
+print(output_details)
+#input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+print("Set input tensor...")
+interpreter.set_tensor(input_details[0]['index'], tensor)
+
+print("invoke()...")
+interpreter.invoke()
+
+# The function `get_tensor()` returns a copy of the tensor data.
+# Use `tensor()` in order to get a pointer to the tensor.
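+# A minimal sketch of the zero-copy alternative mentioned above (not used by this
+# script): `Interpreter.tensor()` returns a callable that yields a numpy view of the
+# underlying tensor buffer; the view is only valid until the interpreter allocates
+# or runs again, so copy it before making further interpreter calls:
+#   output_fn = interpreter.tensor(output_details[0]['index'])
+#   output = output_fn().copy()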
+print("get output tensor...") +output = interpreter.get_tensor(output_details[0]['index']) +#output = np.squeeze(output) +output = output.reshape(width, height) +#print(output) +prediction = np.array(output) +print("reshape prediction...") +prediction = prediction.reshape(width, height) + +# output file +#prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) +print(" Write image to: output.png") +depth_min = prediction.min() +depth_max = prediction.max() +img_out = (255 * (prediction - depth_min) / (depth_max - depth_min)).astype("uint8") +print("save output image...") +cv2.imwrite("output.png", img_out) + +print("finished") \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e86d89d2483f92b7e778589011fad60fbba3a318 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle @@ -0,0 +1,2 @@ +rootProject.name = 'TFLite Image Classification Demo App' +include ':app', ':lib_support', ':lib_task_api', ':models' \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f1150e3379e4a38d31ca7bb46dc4f31d79f482c2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore @@ -0,0 +1,2 @@ +# ignore model file +#*.tflite diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..4917371aa33a65fdfc66c02d914f05489c446430 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj @@ -0,0 +1,538 @@ +// !$*UTF8*$! 
+{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = A1CE41C09920CCEC31985547 /* Pods_Midas.framework */; }; + 8402440123D9834600704ABD /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 8402440023D9834600704ABD /* README.md */; }; + 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */ = {isa = PBXBuildFile; fileRef = 840ECB1F238BAA2300C7D88A /* InfoCell.swift */; }; + 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */; }; + 840EDD022341DE380017ED42 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDD002341DE380017ED42 /* Main.storyboard */; }; + 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 842DDB6D2372A82000F6BB94 /* OverlayView.swift */; }; + 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */; }; + 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846BAF7523E7FE13006FC136 /* Constants.swift */; }; + 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FEC82341D36E00377D34 /* PreviewView.swift */; }; + 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FECA2341D39800377D34 /* CameraFeedManager.swift */; }; + 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */; }; + 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB82361874A0052C104 /* TFLiteExtension.swift */; }; + 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CEE2326338300A11A08 /* AppDelegate.swift */; }; + 84B67CF12326338300A11A08 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CF02326338300A11A08 /* ViewController.swift */; }; + 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 84B67CF52326338400A11A08 /* Assets.xcassets */; }; + 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */; }; + 84F232D5254C831E0011862E /* model_opt.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 84F232D4254C831E0011862E /* model_opt.tflite */; }; + 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */ = {isa = PBXBuildFile; fileRef = 84FCF5912387BD7900663812 /* tfl_logo.png */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 8402440023D9834600704ABD /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; + 840ECB1F238BAA2300C7D88A /* InfoCell.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InfoCell.swift; sourceTree = ""; }; + 840EDCFC2341DDD30017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = "Base.lproj/Launch Screen.storyboard"; sourceTree = ""; }; + 840EDD012341DE380017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = 
Base.lproj/Main.storyboard; sourceTree = ""; }; + 842DDB6D2372A82000F6BB94 /* OverlayView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OverlayView.swift; sourceTree = ""; }; + 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelDataHandler.swift; sourceTree = ""; }; + 846BAF7523E7FE13006FC136 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = ""; }; + 8474FEC82341D36E00377D34 /* PreviewView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreviewView.swift; sourceTree = ""; }; + 8474FECA2341D39800377D34 /* CameraFeedManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CameraFeedManager.swift; sourceTree = ""; }; + 84884291236FF0A30043FC4C /* download_models.sh */ = {isa = PBXFileReference; lastKnownFileType = text.script.sh; path = download_models.sh; sourceTree = ""; }; + 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CVPixelBufferExtension.swift; sourceTree = ""; }; + 84952CB82361874A0052C104 /* TFLiteExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TFLiteExtension.swift; sourceTree = ""; }; + 84B67CEB2326338300A11A08 /* Midas.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Midas.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 84B67CEE2326338300A11A08 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = ""; }; + 84B67CF02326338300A11A08 /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = ""; }; + 84B67CF52326338400A11A08 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 84B67CFA2326338400A11A08 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CGSizeExtension.swift; sourceTree = ""; }; + 84F232D4254C831E0011862E /* model_opt.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = model_opt.tflite; sourceTree = ""; }; + 84FCF5912387BD7900663812 /* tfl_logo.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; name = tfl_logo.png; path = Assets.xcassets/tfl_logo.png; sourceTree = ""; }; + A1CE41C09920CCEC31985547 /* Pods_Midas.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_Midas.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.release.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.release.xcconfig"; sourceTree = ""; }; + FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.debug.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.debug.xcconfig"; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 
84B67CE82326338300A11A08 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 840ECB1E238BAA0D00C7D88A /* Cells */ = { + isa = PBXGroup; + children = ( + 840ECB1F238BAA2300C7D88A /* InfoCell.swift */, + ); + path = Cells; + sourceTree = ""; + }; + 842DDB6C2372A80E00F6BB94 /* Views */ = { + isa = PBXGroup; + children = ( + 842DDB6D2372A82000F6BB94 /* OverlayView.swift */, + ); + path = Views; + sourceTree = ""; + }; + 846499C0235DAAE7009CBBC7 /* ModelDataHandler */ = { + isa = PBXGroup; + children = ( + 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */, + ); + path = ModelDataHandler; + sourceTree = ""; + }; + 8474FEC62341D2BE00377D34 /* ViewControllers */ = { + isa = PBXGroup; + children = ( + 84B67CF02326338300A11A08 /* ViewController.swift */, + ); + path = ViewControllers; + sourceTree = ""; + }; + 8474FEC72341D35800377D34 /* Camera Feed */ = { + isa = PBXGroup; + children = ( + 8474FEC82341D36E00377D34 /* PreviewView.swift */, + 8474FECA2341D39800377D34 /* CameraFeedManager.swift */, + ); + path = "Camera Feed"; + sourceTree = ""; + }; + 84884290236FF07F0043FC4C /* RunScripts */ = { + isa = PBXGroup; + children = ( + 84884291236FF0A30043FC4C /* download_models.sh */, + ); + path = RunScripts; + sourceTree = ""; + }; + 848842A22370180C0043FC4C /* Model */ = { + isa = PBXGroup; + children = ( + 84F232D4254C831E0011862E /* model_opt.tflite */, + ); + path = Model; + sourceTree = ""; + }; + 84952CB3236186A20052C104 /* Extensions */ = { + isa = PBXGroup; + children = ( + 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */, + 84952CB82361874A0052C104 /* TFLiteExtension.swift */, + 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */, + ); + path = Extensions; + sourceTree = ""; + }; + 84B67CE22326338300A11A08 = { + isa = PBXGroup; + children = ( + 8402440023D9834600704ABD /* README.md */, + 84884290236FF07F0043FC4C /* RunScripts */, + 84B67CED2326338300A11A08 /* Midas */, + 84B67CEC2326338300A11A08 /* Products */, + B4DFDCC28443B641BC36251D /* Pods */, + A3DA804B8D3F6891E3A02852 /* Frameworks */, + ); + sourceTree = ""; + }; + 84B67CEC2326338300A11A08 /* Products */ = { + isa = PBXGroup; + children = ( + 84B67CEB2326338300A11A08 /* Midas.app */, + ); + name = Products; + sourceTree = ""; + }; + 84B67CED2326338300A11A08 /* Midas */ = { + isa = PBXGroup; + children = ( + 840ECB1E238BAA0D00C7D88A /* Cells */, + 842DDB6C2372A80E00F6BB94 /* Views */, + 848842A22370180C0043FC4C /* Model */, + 84952CB3236186A20052C104 /* Extensions */, + 846499C0235DAAE7009CBBC7 /* ModelDataHandler */, + 8474FEC72341D35800377D34 /* Camera Feed */, + 8474FEC62341D2BE00377D34 /* ViewControllers */, + 84B67D002326339000A11A08 /* Storyboards */, + 84B67CEE2326338300A11A08 /* AppDelegate.swift */, + 846BAF7523E7FE13006FC136 /* Constants.swift */, + 84B67CF52326338400A11A08 /* Assets.xcassets */, + 84FCF5912387BD7900663812 /* tfl_logo.png */, + 84B67CFA2326338400A11A08 /* Info.plist */, + ); + path = Midas; + sourceTree = ""; + }; + 84B67D002326339000A11A08 /* Storyboards */ = { + isa = PBXGroup; + children = ( + 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */, + 840EDD002341DE380017ED42 /* Main.storyboard */, + ); + path = Storyboards; + sourceTree = ""; + }; + A3DA804B8D3F6891E3A02852 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 
A1CE41C09920CCEC31985547 /* Pods_Midas.framework */, + ); + name = Frameworks; + sourceTree = ""; + }; + B4DFDCC28443B641BC36251D /* Pods */ = { + isa = PBXGroup; + children = ( + FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */, + D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */, + ); + path = Pods; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 84B67CEA2326338300A11A08 /* Midas */ = { + isa = PBXNativeTarget; + buildConfigurationList = 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */; + buildPhases = ( + 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */, + 84884298237010B90043FC4C /* Download TensorFlow Lite model */, + 84B67CE72326338300A11A08 /* Sources */, + 84B67CE82326338300A11A08 /* Frameworks */, + 84B67CE92326338300A11A08 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = Midas; + productName = Midas; + productReference = 84B67CEB2326338300A11A08 /* Midas.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 84B67CE32326338300A11A08 /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 1030; + LastUpgradeCheck = 1030; + ORGANIZATIONNAME = tensorflow; + TargetAttributes = { + 84B67CEA2326338300A11A08 = { + CreatedOnToolsVersion = 10.3; + }; + }; + }; + buildConfigurationList = 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 84B67CE22326338300A11A08; + productRefGroup = 84B67CEC2326338300A11A08 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 84B67CEA2326338300A11A08 /* Midas */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 84B67CE92326338300A11A08 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 8402440123D9834600704ABD /* README.md in Resources */, + 84F232D5254C831E0011862E /* model_opt.tflite in Resources */, + 840EDD022341DE380017ED42 /* Main.storyboard in Resources */, + 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */, + 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */, + 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputFileListPaths = ( + ); + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-Midas-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. 
Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; + 84884298237010B90043FC4C /* Download TensorFlow Lite model */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + ); + name = "Download TensorFlow Lite model"; + outputFileListPaths = ( + ); + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/bash; + shellScript = "\"$SRCROOT/RunScripts/download_models.sh\"\n"; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 84B67CE72326338300A11A08 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */, + 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */, + 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */, + 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */, + 84B67CF12326338300A11A08 /* ViewController.swift in Sources */, + 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */, + 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */, + 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */, + 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */, + 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */, + 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 840EDCFC2341DDD30017ED42 /* Base */, + ); + name = "Launch Screen.storyboard"; + sourceTree = ""; + }; + 840EDD002341DE380017ED42 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 840EDD012341DE380017ED42 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 84B67CFB2326338400A11A08 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = 
dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + 84B67CFC2326338400A11A08 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 84B67CFE2326338400A11A08 /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = BV6M48J3RX; + INFOPLIST_FILE = Midas/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu"; + PRODUCT_NAME = Midas; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 84B67CFF2326338400A11A08 /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = D2BFF06D0AE9137D332447F3 /* 
Pods-Midas.release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = BV6M48J3RX; + INFOPLIST_FILE = Midas/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu"; + PRODUCT_NAME = Midas; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 84B67CFB2326338400A11A08 /* Debug */, + 84B67CFC2326338400A11A08 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 84B67CFE2326338400A11A08 /* Debug */, + 84B67CFF2326338400A11A08 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 84B67CE32326338300A11A08 /* Project object */; +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000000000000000000000000000000000000..919434a6254f0e9651f402737811be6634a03e9c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000000000000000000000000000000000000..18d981003d68d0546c4804ac2ff47dd97c6e7921 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate new file mode 100644 index 0000000000000000000000000000000000000000..1d20756ee57b79e9f9f886453bdb7997ca2ee2d4 Binary files /dev/null and b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate differ diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist 
b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist new file mode 100644 index 0000000000000000000000000000000000000000..6093f6160eedfdfc20e96396247a7dbc9247cc55 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist @@ -0,0 +1,14 @@ + + + + + SchemeUserState + + PoseNet.xcscheme_^#shared#^_ + + orderHint + 3 + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift new file mode 100644 index 0000000000000000000000000000000000000000..233f0291ab4f379067543bdad3cc198a2dc3ab0f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift @@ -0,0 +1,41 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +@UIApplicationMain +class AppDelegate: UIResponder, UIApplicationDelegate { + + var window: UIWindow? + + func application(_ application: UIApplication, didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?) 
-> Bool { + return true + } + + func applicationWillResignActive(_ application: UIApplication) { + } + + func applicationDidEnterBackground(_ application: UIApplication) { + } + + func applicationWillEnterForeground(_ application: UIApplication) { + } + + func applicationDidBecomeActive(_ application: UIApplication) { + } + + func applicationWillTerminate(_ application: UIApplication) { + } +} + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..65b74d7ef11fa59fafa829e681ac90906f3ac8b2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1 @@ +{"images":[{"size":"60x60","expected-size":"180","filename":"180.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"40x40","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"60x60","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"57x57","expected-size":"57","filename":"57.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"87","filename":"87.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"57x57","expected-size":"114","filename":"114.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"60","filename":"60.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"1024x1024","filename":"1024.png","expected-size":"1024","idiom":"ios-marketing","folder":"Assets.xcassets/AppIcon.appiconset/","scale":"1x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"72x72","expected-size":"72","filename":"72.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"76x76","expected-size":"152","filename":"152.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"50x50","expected-size":"100","filename":"100.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"76x76","expected-size":"76","filename":"76.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"50x50","expect
ed-size":"50","filename":"50.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"72x72","expected-size":"144","filename":"144.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"40x40","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"83.5x83.5","expected-size":"167","filename":"167.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"20x20","expected-size":"20","filename":"20.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"}]} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift new file mode 100644 index 0000000000000000000000000000000000000000..48d65b88ee220e722fbad2570c8e879a431cd0f5 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift @@ -0,0 +1,316 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import AVFoundation +import UIKit +import os + +// MARK: - CameraFeedManagerDelegate Declaration +@objc protocol CameraFeedManagerDelegate: class { + /// This method delivers the pixel buffer of the current frame seen by the device's camera. + @objc optional func cameraFeedManager( + _ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer + ) + + /// This method initimates that a session runtime error occured. + func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) + + /// This method initimates that the session was interrupted. + func cameraFeedManager( + _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool + ) + + /// This method initimates that the session interruption has ended. + func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) + + /// This method initimates that there was an error in video configurtion. 
+ func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) + + /// This method initimates that the camera permissions have been denied. + func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) +} + +/// This enum holds the state of the camera initialization. +// MARK: - Camera Initialization State Enum +enum CameraConfiguration { + case success + case failed + case permissionDenied +} + +/// This class manages all camera related functionalities. +// MARK: - Camera Related Functionalies Manager +class CameraFeedManager: NSObject { + // MARK: Camera Related Instance Variables + private let session: AVCaptureSession = AVCaptureSession() + + private let previewView: PreviewView + private let sessionQueue = DispatchQueue(label: "sessionQueue") + private var cameraConfiguration: CameraConfiguration = .failed + private lazy var videoDataOutput = AVCaptureVideoDataOutput() + private var isSessionRunning = false + + // MARK: CameraFeedManagerDelegate + weak var delegate: CameraFeedManagerDelegate? + + // MARK: Initializer + init(previewView: PreviewView) { + self.previewView = previewView + super.init() + + // Initializes the session + session.sessionPreset = .high + self.previewView.session = session + self.previewView.previewLayer.connection?.videoOrientation = .portrait + self.previewView.previewLayer.videoGravity = .resizeAspectFill + self.attemptToConfigureSession() + } + + // MARK: Session Start and End methods + + /// This method starts an AVCaptureSession based on whether the camera configuration was successful. + func checkCameraConfigurationAndStartSession() { + sessionQueue.async { + switch self.cameraConfiguration { + case .success: + self.addObservers() + self.startSession() + case .failed: + DispatchQueue.main.async { + self.delegate?.presentVideoConfigurationErrorAlert(self) + } + case .permissionDenied: + DispatchQueue.main.async { + self.delegate?.presentCameraPermissionsDeniedAlert(self) + } + } + } + } + + /// This method stops a running an AVCaptureSession. + func stopSession() { + self.removeObservers() + sessionQueue.async { + if self.session.isRunning { + self.session.stopRunning() + self.isSessionRunning = self.session.isRunning + } + } + + } + + /// This method resumes an interrupted AVCaptureSession. + func resumeInterruptedSession(withCompletion completion: @escaping (Bool) -> Void) { + sessionQueue.async { + self.startSession() + + DispatchQueue.main.async { + completion(self.isSessionRunning) + } + } + } + + /// This method starts the AVCaptureSession + private func startSession() { + self.session.startRunning() + self.isSessionRunning = self.session.isRunning + } + + // MARK: Session Configuration Methods. + /// This method requests for camera permissions and handles the configuration of the session and stores the result of configuration. + private func attemptToConfigureSession() { + switch AVCaptureDevice.authorizationStatus(for: .video) { + case .authorized: + self.cameraConfiguration = .success + case .notDetermined: + self.sessionQueue.suspend() + self.requestCameraAccess(completion: { granted in + self.sessionQueue.resume() + }) + case .denied: + self.cameraConfiguration = .permissionDenied + default: + break + } + + self.sessionQueue.async { + self.configureSession() + } + } + + /// This method requests for camera permissions. 
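A note on the queue choreography in attemptToConfigureSession() above: sessionQueue.suspend() parks the serial queue before the permission prompt, so the configureSession() work item enqueued at the end of that method cannot run until resume() is called from inside the permission callback. A minimal self-contained sketch of the same pattern with a plain DispatchQueue (names and timings here are illustrative only, not part of the app):

import Dispatch
import Foundation

let sessionQueue = DispatchQueue(label: "example.sessionQueue")
sessionQueue.suspend()                  // park the queue before asking for permission
sessionQueue.async {                    // queued now, but held back until resume()
    print("configuring session")
}
DispatchQueue.global().asyncAfter(deadline: .now() + 0.5) {
    print("permission granted")         // stand-in for the AVCaptureDevice.requestAccess callback
    sessionQueue.resume()               // releases the queued configuration work
}
Thread.sleep(forTimeInterval: 1)        // keep this demo process alive long enough to see both prints

The manager's own permission request is next: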
+ private func requestCameraAccess(completion: @escaping (Bool) -> Void) { + AVCaptureDevice.requestAccess(for: .video) { (granted) in + if !granted { + self.cameraConfiguration = .permissionDenied + } else { + self.cameraConfiguration = .success + } + completion(granted) + } + } + + /// This method handles all the steps to configure an AVCaptureSession. + private func configureSession() { + guard cameraConfiguration == .success else { + return + } + session.beginConfiguration() + + // Tries to add an AVCaptureDeviceInput. + guard addVideoDeviceInput() == true else { + self.session.commitConfiguration() + self.cameraConfiguration = .failed + return + } + + // Tries to add an AVCaptureVideoDataOutput. + guard addVideoDataOutput() else { + self.session.commitConfiguration() + self.cameraConfiguration = .failed + return + } + + session.commitConfiguration() + self.cameraConfiguration = .success + } + + /// This method tries to an AVCaptureDeviceInput to the current AVCaptureSession. + private func addVideoDeviceInput() -> Bool { + /// Tries to get the default back camera. + guard + let camera = AVCaptureDevice.default(.builtInWideAngleCamera, for: .video, position: .back) + else { + fatalError("Cannot find camera") + } + + do { + let videoDeviceInput = try AVCaptureDeviceInput(device: camera) + if session.canAddInput(videoDeviceInput) { + session.addInput(videoDeviceInput) + return true + } else { + return false + } + } catch { + fatalError("Cannot create video device input") + } + } + + /// This method tries to an AVCaptureVideoDataOutput to the current AVCaptureSession. + private func addVideoDataOutput() -> Bool { + let sampleBufferQueue = DispatchQueue(label: "sampleBufferQueue") + videoDataOutput.setSampleBufferDelegate(self, queue: sampleBufferQueue) + videoDataOutput.alwaysDiscardsLateVideoFrames = true + videoDataOutput.videoSettings = [ + String(kCVPixelBufferPixelFormatTypeKey): kCMPixelFormat_32BGRA + ] + + if session.canAddOutput(videoDataOutput) { + session.addOutput(videoDataOutput) + videoDataOutput.connection(with: .video)?.videoOrientation = .portrait + return true + } + return false + } + + // MARK: Notification Observer Handling + private func addObservers() { + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionRuntimeErrorOccured(notification:)), + name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session) + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionWasInterrupted(notification:)), + name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session) + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionInterruptionEnded), + name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session) + } + + private func removeObservers() { + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session) + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session) + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session) + } + + // MARK: Notification Observers + @objc func sessionWasInterrupted(notification: Notification) { + if let userInfoValue = notification.userInfo?[AVCaptureSessionInterruptionReasonKey] + as AnyObject?, + let reasonIntegerValue = userInfoValue.integerValue, + let reason = 
AVCaptureSession.InterruptionReason(rawValue: reasonIntegerValue) + { + os_log("Capture session was interrupted with reason: %s", type: .error, reason.rawValue) + + var canResumeManually = false + if reason == .videoDeviceInUseByAnotherClient { + canResumeManually = true + } else if reason == .videoDeviceNotAvailableWithMultipleForegroundApps { + canResumeManually = false + } + + delegate?.cameraFeedManager(self, sessionWasInterrupted: canResumeManually) + + } + } + + @objc func sessionInterruptionEnded(notification: Notification) { + delegate?.cameraFeedManagerDidEndSessionInterruption(self) + } + + @objc func sessionRuntimeErrorOccured(notification: Notification) { + guard let error = notification.userInfo?[AVCaptureSessionErrorKey] as? AVError else { + return + } + + os_log("Capture session runtime error: %s", type: .error, error.localizedDescription) + + if error.code == .mediaServicesWereReset { + sessionQueue.async { + if self.isSessionRunning { + self.startSession() + } else { + DispatchQueue.main.async { + self.delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self) + } + } + } + } else { + delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self) + } + } +} + +/// AVCaptureVideoDataOutputSampleBufferDelegate +extension CameraFeedManager: AVCaptureVideoDataOutputSampleBufferDelegate { + /// This method delegates the CVPixelBuffer of the frame seen by the camera currently. + func captureOutput( + _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, + from connection: AVCaptureConnection + ) { + + // Converts the CMSampleBuffer to a CVPixelBuffer. + let pixelBuffer: CVPixelBuffer? = CMSampleBufferGetImageBuffer(sampleBuffer) + + guard let imagePixelBuffer = pixelBuffer else { + return + } + + // Delegates the pixel buffer to the ViewController. + delegate?.cameraFeedManager?(self, didOutput: imagePixelBuffer) + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift new file mode 100644 index 0000000000000000000000000000000000000000..308c7ec54308af5c152ff6038670b26501a8e82c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift @@ -0,0 +1,39 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit +import AVFoundation + + /// The camera frame is displayed on this view. +class PreviewView: UIView { + var previewLayer: AVCaptureVideoPreviewLayer { + guard let layer = layer as? AVCaptureVideoPreviewLayer else { + fatalError("Layer expected is of type VideoPreviewLayer") + } + return layer + } + + var session: AVCaptureSession? 
{ + get { + return previewLayer.session + } + set { + previewLayer.session = newValue + } + } + + override class var layerClass: AnyClass { + return AVCaptureVideoPreviewLayer.self + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift new file mode 100644 index 0000000000000000000000000000000000000000..c6be64af5678541ec09fc367b03c80155876f0ba --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift @@ -0,0 +1,21 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +/// Table cell for inference result in bottom view. +class InfoCell: UITableViewCell { + @IBOutlet weak var fieldNameLabel: UILabel! + @IBOutlet weak var infoLabel: UILabel! +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift new file mode 100644 index 0000000000000000000000000000000000000000..b0789ee58a1ea373d441f05333d8ce8914adadb7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift @@ -0,0 +1,25 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +enum Constants { + // MARK: - Constants related to the image processing + static let bgraPixel = (channels: 4, alphaComponent: 3, lastBgrComponent: 2) + static let rgbPixelChannels = 3 + static let maxRGBValue: Float32 = 255.0 + + // MARK: - Constants related to the model interperter + static let defaultThreadCount = 2 + static let defaultDelegate: Delegates = .CPU +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..031550ea0081963d18b5b83712854babaf7c0a34 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift @@ -0,0 +1,45 @@ +// Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import Foundation + +extension CGSize { + /// Returns `CGAfineTransform` to resize `self` to fit in destination size, keeping aspect ratio + /// of `self`. `self` image is resized to be inscribe to destination size and located in center of + /// destination. + /// + /// - Parameter toFitIn: destination size to be filled. + /// - Returns: `CGAffineTransform` to transform `self` image to `dest` image. + func transformKeepAspect(toFitIn dest: CGSize) -> CGAffineTransform { + let sourceRatio = self.height / self.width + let destRatio = dest.height / dest.width + + // Calculates ratio `self` to `dest`. + var ratio: CGFloat + var x: CGFloat = 0 + var y: CGFloat = 0 + if sourceRatio > destRatio { + // Source size is taller than destination. Resized to fit in destination height, and find + // horizontal starting point to be centered. + ratio = dest.height / self.height + x = (dest.width - self.width * ratio) / 2 + } else { + ratio = dest.width / self.width + y = (dest.height - self.height * ratio) / 2 + } + return CGAffineTransform(a: ratio, b: 0, c: 0, d: ratio, tx: x, ty: y) + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..4899c76562a546c513736fbf4556629b08d2c929 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift @@ -0,0 +1,172 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import Foundation + +extension CVPixelBuffer { + var size: CGSize { + return CGSize(width: CVPixelBufferGetWidth(self), height: CVPixelBufferGetHeight(self)) + } + + /// Returns a new `CVPixelBuffer` created by taking the self area and resizing it to the + /// specified target size. Aspect ratios of source image and destination image are expected to be + /// same. + /// + /// - Parameters: + /// - from: Source area of image to be cropped and resized. + /// - to: Size to scale the image to(i.e. 
image size used while training the model). + /// - Returns: The cropped and resized image of itself. + func resize(from source: CGRect, to size: CGSize) -> CVPixelBuffer? { + let rect = CGRect(origin: CGPoint(x: 0, y: 0), size: self.size) + guard rect.contains(source) else { + os_log("Resizing Error: source area is out of index", type: .error) + return nil + } + guard rect.size.width / rect.size.height - source.size.width / source.size.height < 1e-5 + else { + os_log( + "Resizing Error: source image ratio and destination image ratio is different", + type: .error) + return nil + } + + let inputImageRowBytes = CVPixelBufferGetBytesPerRow(self) + let imageChannels = 4 + + CVPixelBufferLockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) + defer { CVPixelBufferUnlockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) } + + // Finds the address of the upper leftmost pixel of the source area. + guard + let inputBaseAddress = CVPixelBufferGetBaseAddress(self)?.advanced( + by: Int(source.minY) * inputImageRowBytes + Int(source.minX) * imageChannels) + else { + return nil + } + + // Crops given area as vImage Buffer. + var croppedImage = vImage_Buffer( + data: inputBaseAddress, height: UInt(source.height), width: UInt(source.width), + rowBytes: inputImageRowBytes) + + let resultRowBytes = Int(size.width) * imageChannels + guard let resultAddress = malloc(Int(size.height) * resultRowBytes) else { + return nil + } + + // Allocates a vacant vImage buffer for resized image. + var resizedImage = vImage_Buffer( + data: resultAddress, + height: UInt(size.height), width: UInt(size.width), + rowBytes: resultRowBytes + ) + + // Performs the scale operation on cropped image and stores it in result image buffer. + guard vImageScale_ARGB8888(&croppedImage, &resizedImage, nil, vImage_Flags(0)) == kvImageNoError + else { + return nil + } + + let releaseCallBack: CVPixelBufferReleaseBytesCallback = { mutablePointer, pointer in + if let pointer = pointer { + free(UnsafeMutableRawPointer(mutating: pointer)) + } + } + + var result: CVPixelBuffer? + + // Converts the thumbnail vImage buffer to CVPixelBuffer + let conversionStatus = CVPixelBufferCreateWithBytes( + nil, + Int(size.width), Int(size.height), + CVPixelBufferGetPixelFormatType(self), + resultAddress, + resultRowBytes, + releaseCallBack, + nil, + nil, + &result + ) + + guard conversionStatus == kCVReturnSuccess else { + free(resultAddress) + return nil + } + + return result + } + + /// Returns the RGB `Data` representation of the given image buffer. + /// + /// - Parameters: + /// - isModelQuantized: Whether the model is quantized (i.e. fixed point values rather than + /// floating point values). + /// - Returns: The RGB data representation of the image buffer or `nil` if the buffer could not be + /// converted. + func rgbData( + isModelQuantized: Bool + ) -> Data? { + CVPixelBufferLockBaseAddress(self, .readOnly) + defer { CVPixelBufferUnlockBaseAddress(self, .readOnly) } + guard let sourceData = CVPixelBufferGetBaseAddress(self) else { + return nil + } + + let width = CVPixelBufferGetWidth(self) + let height = CVPixelBufferGetHeight(self) + let sourceBytesPerRow = CVPixelBufferGetBytesPerRow(self) + let destinationBytesPerRow = Constants.rgbPixelChannels * width + + // Assign input image to `sourceBuffer` to convert it. 
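Before the camera frame is handed to vImage below, it may help to see the BGRA-to-RGB step in isolation. vImage_Buffer only wraps existing pixel memory with width/height/rowBytes metadata; the actual copy happens in the convert call. A tiny self-contained illustration with two hand-written pixels (values arbitrary):

import Accelerate

// Two BGRA pixels: pure blue and pure red (byte order B, G, R, A).
var bgra: [UInt8] = [255, 0, 0, 255,   0, 0, 255, 255]
var rgb = [UInt8](repeating: 0, count: 6)

bgra.withUnsafeMutableBytes { src in
    rgb.withUnsafeMutableBytes { dst in
        var srcBuffer = vImage_Buffer(data: src.baseAddress, height: 1, width: 2, rowBytes: 8)
        var dstBuffer = vImage_Buffer(data: dst.baseAddress, height: 1, width: 2, rowBytes: 6)
        _ = vImageConvert_BGRA8888toRGB888(&srcBuffer, &dstBuffer, vImage_Flags(kvImageNoFlags))
    }
}
// rgb is now [0, 0, 255, 255, 0, 0]: blue becomes RGB (0, 0, 255), red becomes (255, 0, 0).

The locked camera buffer is wrapped in the same way: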
+ var sourceBuffer = vImage_Buffer( + data: sourceData, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: sourceBytesPerRow) + + // Make `destinationBuffer` and `destinationData` for its data to be assigned. + guard let destinationData = malloc(height * destinationBytesPerRow) else { + os_log("Error: out of memory", type: .error) + return nil + } + defer { free(destinationData) } + var destinationBuffer = vImage_Buffer( + data: destinationData, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: destinationBytesPerRow) + + // Convert image type. + switch CVPixelBufferGetPixelFormatType(self) { + case kCVPixelFormatType_32BGRA: + vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags)) + case kCVPixelFormatType_32ARGB: + vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags)) + default: + os_log("The type of this image is not supported.", type: .error) + return nil + } + + // Make `Data` with converted image. + let imageByteData = Data( + bytes: destinationBuffer.data, count: destinationBuffer.rowBytes * height) + + if isModelQuantized { return imageByteData } + + let imageBytes = [UInt8](imageByteData) + return Data(copyingBufferOf: imageBytes.map { Float($0) / Constants.maxRGBValue }) + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..63f7ced786e2b550391c77af534d1d3c431522c6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift @@ -0,0 +1,75 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import CoreImage +import Foundation +import TensorFlowLite + +// MARK: - Data +extension Data { + /// Creates a new buffer by copying the buffer pointer of the given array. + /// + /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit + /// for bit with no indirection or reference-counting operations; otherwise, reinterpreting + /// data from the resulting buffer has undefined behavior. + /// - Parameter array: An array with elements of type `T`. + init(copyingBufferOf array: [T]) { + self = array.withUnsafeBufferPointer(Data.init) + } + + /// Convert a Data instance to Array representation. + func toArray(type: T.Type) -> [T] where T: AdditiveArithmetic { + var array = [T](repeating: T.zero, count: self.count / MemoryLayout.stride) + _ = array.withUnsafeMutableBytes { self.copyBytes(to: $0) } + return array + } +} + +// MARK: - Wrappers +/// Struct for handling multidimension `Data` in flat `Array`. 
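The flat index used by the FlatArray wrapper declared next is ordinary row-major arithmetic (result = dimensions[i] * result + index[i], accumulated left to right). For a tensor of shape [1, 256, 256, 3], element [0, 2, 3, 1] therefore lands at ((0 * 256 + 2) * 256 + 3) * 3 + 1 = 1546. A two-line check of the same formula, independent of TensorFlowLite:

let dims = [1, 256, 256, 3]
let index = [0, 2, 3, 1]
let flat = zip(dims, index).reduce(0) { $0 * $1.0 + $1.1 }   // ((0*256 + 2)*256 + 3)*3 + 1 == 1546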
+struct FlatArray { + private var array: [Element] + var dimensions: [Int] + + init(tensor: Tensor) { + dimensions = tensor.shape.dimensions + array = tensor.data.toArray(type: Element.self) + } + + private func flatIndex(_ index: [Int]) -> Int { + guard index.count == dimensions.count else { + fatalError("Invalid index: got \(index.count) index(es) for \(dimensions.count) index(es).") + } + + var result = 0 + for i in 0.. index[i] else { + fatalError("Invalid index: \(index[i]) is bigger than \(dimensions[i])") + } + result = dimensions[i] * result + index[i] + } + return result + } + + subscript(_ index: Int...) -> Element { + get { + return array[flatIndex(index)] + } + set(newValue) { + array[flatIndex(index)] = newValue + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..4330d9b33f31010549802febc6f6f2bc9fd9b950 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist @@ -0,0 +1,42 @@ + + + + + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + NSCameraUsageDescription + This app will use camera to continuously estimate the depth map. + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift new file mode 100644 index 0000000000000000000000000000000000000000..144cfe1fa3a65af5adcb572237f2bf9718e570ae --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift @@ -0,0 +1,464 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Accelerate +import CoreImage +import Foundation +import TensorFlowLite +import UIKit + +/// This class handles all data preprocessing and makes calls to run inference on a given frame +/// by invoking the `Interpreter`. It then formats the inferences obtained. +class ModelDataHandler { + // MARK: - Private Properties + + /// TensorFlow Lite `Interpreter` object for performing inference on a given model. + private var interpreter: Interpreter + + /// TensorFlow lite `Tensor` of model input and output. 
+ private var inputTensor: Tensor + + //private var heatsTensor: Tensor + //private var offsetsTensor: Tensor + private var outputTensor: Tensor + // MARK: - Initialization + + /// A failable initializer for `ModelDataHandler`. A new instance is created if the model is + /// successfully loaded from the app's main bundle. Default `threadCount` is 2. + init( + threadCount: Int = Constants.defaultThreadCount, + delegate: Delegates = Constants.defaultDelegate + ) throws { + // Construct the path to the model file. + guard + let modelPath = Bundle.main.path( + forResource: Model.file.name, + ofType: Model.file.extension + ) + else { + fatalError("Failed to load the model file with name: \(Model.file.name).") + } + + // Specify the options for the `Interpreter`. + var options = Interpreter.Options() + options.threadCount = threadCount + + // Specify the delegates for the `Interpreter`. + var delegates: [Delegate]? + switch delegate { + case .Metal: + delegates = [MetalDelegate()] + case .CoreML: + if let coreMLDelegate = CoreMLDelegate() { + delegates = [coreMLDelegate] + } else { + delegates = nil + } + default: + delegates = nil + } + + // Create the `Interpreter`. + interpreter = try Interpreter(modelPath: modelPath, options: options, delegates: delegates) + + // Initialize input and output `Tensor`s. + // Allocate memory for the model's input `Tensor`s. + try interpreter.allocateTensors() + + // Get allocated input and output `Tensor`s. + inputTensor = try interpreter.input(at: 0) + outputTensor = try interpreter.output(at: 0) + //heatsTensor = try interpreter.output(at: 0) + //offsetsTensor = try interpreter.output(at: 1) + + /* + // Check if input and output `Tensor`s are in the expected formats. + guard (inputTensor.dataType == .uInt8) == Model.isQuantized else { + fatalError("Unexpected Model: quantization is \(!Model.isQuantized)") + } + + guard inputTensor.shape.dimensions[0] == Model.input.batchSize, + inputTensor.shape.dimensions[1] == Model.input.height, + inputTensor.shape.dimensions[2] == Model.input.width, + inputTensor.shape.dimensions[3] == Model.input.channelSize + else { + fatalError("Unexpected Model: input shape") + } + + + guard heatsTensor.shape.dimensions[0] == Model.output.batchSize, + heatsTensor.shape.dimensions[1] == Model.output.height, + heatsTensor.shape.dimensions[2] == Model.output.width, + heatsTensor.shape.dimensions[3] == Model.output.keypointSize + else { + fatalError("Unexpected Model: heat tensor") + } + + guard offsetsTensor.shape.dimensions[0] == Model.output.batchSize, + offsetsTensor.shape.dimensions[1] == Model.output.height, + offsetsTensor.shape.dimensions[2] == Model.output.width, + offsetsTensor.shape.dimensions[3] == Model.output.offsetSize + else { + fatalError("Unexpected Model: offset tensor") + } + */ + + } + + /// Runs Midas model with given image with given source area to destination area. + /// + /// - Parameters: + /// - on: Input image to run the model. + /// - from: Range of input image to run the model. + /// - to: Size of view to render the result. + /// - Returns: Result of the inference and the times consumed in every steps. + func runMidas(on pixelbuffer: CVPixelBuffer, from source: CGRect, to dest: CGSize) + //-> (Result, Times)? + //-> (FlatArray, Times)? + -> ([Float], Int, Int, Times)? + { + // Start times of each process. + let preprocessingStartTime: Date + let inferenceStartTime: Date + let postprocessingStartTime: Date + + // Processing times in miliseconds. 
+ let preprocessingTime: TimeInterval + let inferenceTime: TimeInterval + let postprocessingTime: TimeInterval + + preprocessingStartTime = Date() + guard let data = preprocess(of: pixelbuffer, from: source) else { + os_log("Preprocessing failed", type: .error) + return nil + } + preprocessingTime = Date().timeIntervalSince(preprocessingStartTime) * 1000 + + inferenceStartTime = Date() + inference(from: data) + inferenceTime = Date().timeIntervalSince(inferenceStartTime) * 1000 + + postprocessingStartTime = Date() + //guard let result = postprocess(to: dest) else { + // os_log("Postprocessing failed", type: .error) + // return nil + //} + postprocessingTime = Date().timeIntervalSince(postprocessingStartTime) * 1000 + + + let results: [Float] + switch outputTensor.dataType { + case .uInt8: + guard let quantization = outputTensor.quantizationParameters else { + print("No results returned because the quantization values for the output tensor are nil.") + return nil + } + let quantizedResults = [UInt8](outputTensor.data) + results = quantizedResults.map { + quantization.scale * Float(Int($0) - quantization.zeroPoint) + } + case .float32: + results = [Float32](unsafeData: outputTensor.data) ?? [] + default: + print("Output tensor data type \(outputTensor.dataType) is unsupported for this example app.") + return nil + } + + + let times = Times( + preprocessing: preprocessingTime, + inference: inferenceTime, + postprocessing: postprocessingTime) + + return (results, Model.input.width, Model.input.height, times) + } + + // MARK: - Private functions to run model + /// Preprocesses given rectangle image to be `Data` of disired size by croping and resizing it. + /// + /// - Parameters: + /// - of: Input image to crop and resize. + /// - from: Target area to be cropped and resized. + /// - Returns: The cropped and resized image. `nil` if it can not be processed. + private func preprocess(of pixelBuffer: CVPixelBuffer, from targetSquare: CGRect) -> Data? { + let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer) + assert(sourcePixelFormat == kCVPixelFormatType_32BGRA) + + // Resize `targetSquare` of input image to `modelSize`. + let modelSize = CGSize(width: Model.input.width, height: Model.input.height) + guard let thumbnail = pixelBuffer.resize(from: targetSquare, to: modelSize) + else { + return nil + } + + // Remove the alpha component from the image buffer to get the initialized `Data`. + let byteCount = + Model.input.batchSize + * Model.input.height * Model.input.width + * Model.input.channelSize + guard + let inputData = thumbnail.rgbData( + isModelQuantized: Model.isQuantized + ) + else { + os_log("Failed to convert the image buffer to RGB data.", type: .error) + return nil + } + + return inputData + } + + + + /* + /// Postprocesses output `Tensor`s to `Result` with size of view to render the result. + /// + /// - Parameters: + /// - to: Size of view to be displaied. + /// - Returns: Postprocessed `Result`. `nil` if it can not be processed. + private func postprocess(to viewSize: CGSize) -> Result? { + // MARK: Formats output tensors + // Convert `Tensor` to `FlatArray`. As Midas is not quantized, convert them to Float type + // `FlatArray`. + let heats = FlatArray(tensor: heatsTensor) + let offsets = FlatArray(tensor: offsetsTensor) + + // MARK: Find position of each key point + // Finds the (row, col) locations of where the keypoints are most likely to be. The highest + // `heats[0, row, col, keypoint]` value, the more likely `keypoint` being located in (`row`, + // `col`). 
+ let keypointPositions = (0.. (Int, Int) in + var maxValue = heats[0, 0, 0, keypoint] + var maxRow = 0 + var maxCol = 0 + for row in 0.. maxValue { + maxValue = heats[0, row, col, keypoint] + maxRow = row + maxCol = col + } + } + } + return (maxRow, maxCol) + } + + // MARK: Calculates total confidence score + // Calculates total confidence score of each key position. + let totalScoreSum = keypointPositions.enumerated().reduce(0.0) { accumulator, elem -> Float32 in + accumulator + sigmoid(heats[0, elem.element.0, elem.element.1, elem.offset]) + } + let totalScore = totalScoreSum / Float32(Model.output.keypointSize) + + // MARK: Calculate key point position on model input + // Calculates `KeyPoint` coordination model input image with `offsets` adjustment. + let coords = keypointPositions.enumerated().map { index, elem -> (y: Float32, x: Float32) in + let (y, x) = elem + let yCoord = + Float32(y) / Float32(Model.output.height - 1) * Float32(Model.input.height) + + offsets[0, y, x, index] + let xCoord = + Float32(x) / Float32(Model.output.width - 1) * Float32(Model.input.width) + + offsets[0, y, x, index + Model.output.keypointSize] + return (y: yCoord, x: xCoord) + } + + // MARK: Transform key point position and make lines + // Make `Result` from `keypointPosition'. Each point is adjusted to `ViewSize` to be drawn. + var result = Result(dots: [], lines: [], score: totalScore) + var bodyPartToDotMap = [BodyPart: CGPoint]() + for (index, part) in BodyPart.allCases.enumerated() { + let position = CGPoint( + x: CGFloat(coords[index].x) * viewSize.width / CGFloat(Model.input.width), + y: CGFloat(coords[index].y) * viewSize.height / CGFloat(Model.input.height) + ) + bodyPartToDotMap[part] = position + result.dots.append(position) + } + + do { + try result.lines = BodyPart.lines.map { map throws -> Line in + guard let from = bodyPartToDotMap[map.from] else { + throw PostprocessError.missingBodyPart(of: map.from) + } + guard let to = bodyPartToDotMap[map.to] else { + throw PostprocessError.missingBodyPart(of: map.to) + } + return Line(from: from, to: to) + } + } catch PostprocessError.missingBodyPart(let missingPart) { + os_log("Postprocessing error: %s is missing.", type: .error, missingPart.rawValue) + return nil + } catch { + os_log("Postprocessing error: %s", type: .error, error.localizedDescription) + return nil + } + + return result + } +*/ + + + + /// Run inference with given `Data` + /// + /// Parameter `from`: `Data` of input image to run model. + private func inference(from data: Data) { + // Copy the initialized `Data` to the input `Tensor`. + do { + try interpreter.copy(data, toInputAt: 0) + + // Run inference by invoking the `Interpreter`. + try interpreter.invoke() + + // Get the output `Tensor` to process the inference results. + outputTensor = try interpreter.output(at: 0) + //heatsTensor = try interpreter.output(at: 0) + //offsetsTensor = try interpreter.output(at: 1) + + + } catch let error { + os_log( + "Failed to invoke the interpreter with error: %s", type: .error, + error.localizedDescription) + return + } + } + + /// Returns value within [0,1]. 
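A side note on the .uInt8 branch of runMidas(on:from:to:) above: the affine dequantisation it applies is value = scale * (rawByte - zeroPoint). For example, with a scale of 0.05 and a zero point of 128, a raw byte of 200 maps to 0.05 * (200 - 128) = 3.6. The scale and zero point here are made-up illustration values; the real ones come from outputTensor.quantizationParameters.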
+ private func sigmoid(_ x: Float32) -> Float32 { + return (1.0 / (1.0 + exp(-x))) + } +} + +// MARK: - Data types for inference result +struct KeyPoint { + var bodyPart: BodyPart = BodyPart.NOSE + var position: CGPoint = CGPoint() + var score: Float = 0.0 +} + +struct Line { + let from: CGPoint + let to: CGPoint +} + +struct Times { + var preprocessing: Double + var inference: Double + var postprocessing: Double +} + +struct Result { + var dots: [CGPoint] + var lines: [Line] + var score: Float +} + +enum BodyPart: String, CaseIterable { + case NOSE = "nose" + case LEFT_EYE = "left eye" + case RIGHT_EYE = "right eye" + case LEFT_EAR = "left ear" + case RIGHT_EAR = "right ear" + case LEFT_SHOULDER = "left shoulder" + case RIGHT_SHOULDER = "right shoulder" + case LEFT_ELBOW = "left elbow" + case RIGHT_ELBOW = "right elbow" + case LEFT_WRIST = "left wrist" + case RIGHT_WRIST = "right wrist" + case LEFT_HIP = "left hip" + case RIGHT_HIP = "right hip" + case LEFT_KNEE = "left knee" + case RIGHT_KNEE = "right knee" + case LEFT_ANKLE = "left ankle" + case RIGHT_ANKLE = "right ankle" + + /// List of lines connecting each part. + static let lines = [ + (from: BodyPart.LEFT_WRIST, to: BodyPart.LEFT_ELBOW), + (from: BodyPart.LEFT_ELBOW, to: BodyPart.LEFT_SHOULDER), + (from: BodyPart.LEFT_SHOULDER, to: BodyPart.RIGHT_SHOULDER), + (from: BodyPart.RIGHT_SHOULDER, to: BodyPart.RIGHT_ELBOW), + (from: BodyPart.RIGHT_ELBOW, to: BodyPart.RIGHT_WRIST), + (from: BodyPart.LEFT_SHOULDER, to: BodyPart.LEFT_HIP), + (from: BodyPart.LEFT_HIP, to: BodyPart.RIGHT_HIP), + (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_SHOULDER), + (from: BodyPart.LEFT_HIP, to: BodyPart.LEFT_KNEE), + (from: BodyPart.LEFT_KNEE, to: BodyPart.LEFT_ANKLE), + (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_KNEE), + (from: BodyPart.RIGHT_KNEE, to: BodyPart.RIGHT_ANKLE), + ] +} + +// MARK: - Delegates Enum +enum Delegates: Int, CaseIterable { + case CPU + case Metal + case CoreML + + var description: String { + switch self { + case .CPU: + return "CPU" + case .Metal: + return "GPU" + case .CoreML: + return "NPU" + } + } +} + +// MARK: - Custom Errors +enum PostprocessError: Error { + case missingBodyPart(of: BodyPart) +} + +// MARK: - Information about the model file. +typealias FileInfo = (name: String, extension: String) + +enum Model { + static let file: FileInfo = ( + name: "model_opt", extension: "tflite" + ) + + static let input = (batchSize: 1, height: 256, width: 256, channelSize: 3) + static let output = (batchSize: 1, height: 256, width: 256, channelSize: 1) + static let isQuantized = false +} + + +extension Array { + /// Creates a new array from the bytes of the given unsafe data. + /// + /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit + /// with no indirection or reference-counting operations; otherwise, copying the raw bytes in + /// the `unsafeData`'s buffer to a new array returns an unsafe copy. + /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of + /// `MemoryLayout.stride`. + /// - Parameter unsafeData: The data containing the bytes to turn into an array. 
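A quick usage sketch of the failable initializer defined just below (the twelve bytes are the little-endian Float32 encodings of 1.0, 2.0 and 3.0):

import Foundation

let raw = Data([0, 0, 128, 63,   0, 0, 0, 64,   0, 0, 64, 64])
let floats = [Float32](unsafeData: raw)               // Optional([1.0, 2.0, 3.0]) on little-endian hardware
let truncated = [Float32](unsafeData: Data([1, 2]))   // nil: 2 bytes is not a multiple of MemoryLayout<Float32>.stride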
+ init?(unsafeData: Data) { + guard unsafeData.count % MemoryLayout.stride == 0 else { return nil } + #if swift(>=5.0) + self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) } + #else + self = unsafeData.withUnsafeBytes { + .init(UnsafeBufferPointer( + start: $0, + count: unsafeData.count / MemoryLayout.stride + )) + } + #endif // swift(>=5.0) + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..a04c79f554777863bd0dc8287bfd60704ce28bf2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..5f5623794bd35b9bb75efd7b7e249fd7357fdfbd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard @@ -0,0 +1,236 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift new file mode 100644 index 0000000000000000000000000000000000000000..fbb51b5a303412c0bbd158d76d025cf88fee6f8f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift @@ -0,0 +1,489 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
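The ViewController that follows renders the model output by min–max normalising every depth value into the 0…255 grayscale range. Factored out of the controller, that mapping is roughly equivalent to the small helper below; the name depthToGray is hypothetical and shown only for clarity, since the shipped code inlines this logic in runModel(on:):

func depthToGray(_ depth: [Float]) -> [UInt8] {
    // Map the minimum depth to 0 and the maximum to 255; render a constant input as black.
    guard let minVal = depth.min(), let maxVal = depth.max(), maxVal > minVal else {
        return [UInt8](repeating: 0, count: depth.count)
    }
    let scale = 255 / (maxVal - minVal)
    return depth.map { UInt8(($0 - minVal) * scale) }
}

// depthToGray([0.0, 2.0, 1.0]) == [0, 255, 127]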
+ +import AVFoundation +import UIKit +import os + + +public struct PixelData { + var a: UInt8 + var r: UInt8 + var g: UInt8 + var b: UInt8 +} + +extension UIImage { + convenience init?(pixels: [PixelData], width: Int, height: Int) { + guard width > 0 && height > 0, pixels.count == width * height else { return nil } + var data = pixels + guard let providerRef = CGDataProvider(data: Data(bytes: &data, count: data.count * MemoryLayout.size) as CFData) + else { return nil } + guard let cgim = CGImage( + width: width, + height: height, + bitsPerComponent: 8, + bitsPerPixel: 32, + bytesPerRow: width * MemoryLayout.size, + space: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue), + provider: providerRef, + decode: nil, + shouldInterpolate: false, + intent: .defaultIntent) + else { return nil } + self.init(cgImage: cgim) + } +} + + +class ViewController: UIViewController { + // MARK: Storyboards Connections + @IBOutlet weak var previewView: PreviewView! + + //@IBOutlet weak var overlayView: OverlayView! + @IBOutlet weak var overlayView: UIImageView! + + private var imageView : UIImageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400)) + + private var imageViewInitialized: Bool = false + + @IBOutlet weak var resumeButton: UIButton! + @IBOutlet weak var cameraUnavailableLabel: UILabel! + + @IBOutlet weak var tableView: UITableView! + + @IBOutlet weak var threadCountLabel: UILabel! + @IBOutlet weak var threadCountStepper: UIStepper! + + @IBOutlet weak var delegatesControl: UISegmentedControl! + + // MARK: ModelDataHandler traits + var threadCount: Int = Constants.defaultThreadCount + var delegate: Delegates = Constants.defaultDelegate + + // MARK: Result Variables + // Inferenced data to render. + private var inferencedData: InferencedData? + + // Minimum score to render the result. + private let minimumScore: Float = 0.5 + + private var avg_latency: Double = 0.0 + + // Relative location of `overlayView` to `previewView`. + private var overlayViewFrame: CGRect? + + private var previewViewFrame: CGRect? + + // MARK: Controllers that manage functionality + // Handles all the camera related functionality + private lazy var cameraCapture = CameraFeedManager(previewView: previewView) + + // Handles all data preprocessing and makes calls to run inference. + private var modelDataHandler: ModelDataHandler? + + // MARK: View Handling Methods + override func viewDidLoad() { + super.viewDidLoad() + + do { + modelDataHandler = try ModelDataHandler() + } catch let error { + fatalError(error.localizedDescription) + } + + cameraCapture.delegate = self + tableView.delegate = self + tableView.dataSource = self + + // MARK: UI Initialization + // Setup thread count stepper with white color. + // https://forums.developer.apple.com/thread/121495 + threadCountStepper.setDecrementImage( + threadCountStepper.decrementImage(for: .normal), for: .normal) + threadCountStepper.setIncrementImage( + threadCountStepper.incrementImage(for: .normal), for: .normal) + // Setup initial stepper value and its label. + threadCountStepper.value = Double(Constants.defaultThreadCount) + threadCountLabel.text = Constants.defaultThreadCount.description + + // Setup segmented controller's color. 
+ delegatesControl.setTitleTextAttributes( + [NSAttributedString.Key.foregroundColor: UIColor.lightGray], + for: .normal) + delegatesControl.setTitleTextAttributes( + [NSAttributedString.Key.foregroundColor: UIColor.black], + for: .selected) + // Remove existing segments to initialize it with `Delegates` entries. + delegatesControl.removeAllSegments() + Delegates.allCases.forEach { delegate in + delegatesControl.insertSegment( + withTitle: delegate.description, + at: delegate.rawValue, + animated: false) + } + delegatesControl.selectedSegmentIndex = 0 + } + + override func viewWillAppear(_ animated: Bool) { + super.viewWillAppear(animated) + + cameraCapture.checkCameraConfigurationAndStartSession() + } + + override func viewWillDisappear(_ animated: Bool) { + cameraCapture.stopSession() + } + + override func viewDidLayoutSubviews() { + overlayViewFrame = overlayView.frame + previewViewFrame = previewView.frame + } + + // MARK: Button Actions + @IBAction func didChangeThreadCount(_ sender: UIStepper) { + let changedCount = Int(sender.value) + if threadCountLabel.text == changedCount.description { + return + } + + do { + modelDataHandler = try ModelDataHandler(threadCount: changedCount, delegate: delegate) + } catch let error { + fatalError(error.localizedDescription) + } + threadCount = changedCount + threadCountLabel.text = changedCount.description + os_log("Thread count is changed to: %d", threadCount) + } + + @IBAction func didChangeDelegate(_ sender: UISegmentedControl) { + guard let changedDelegate = Delegates(rawValue: delegatesControl.selectedSegmentIndex) else { + fatalError("Unexpected value from delegates segemented controller.") + } + do { + modelDataHandler = try ModelDataHandler(threadCount: threadCount, delegate: changedDelegate) + } catch let error { + fatalError(error.localizedDescription) + } + delegate = changedDelegate + os_log("Delegate is changed to: %s", delegate.description) + } + + @IBAction func didTapResumeButton(_ sender: Any) { + cameraCapture.resumeInterruptedSession { complete in + + if complete { + self.resumeButton.isHidden = true + self.cameraUnavailableLabel.isHidden = true + } else { + self.presentUnableToResumeSessionAlert() + } + } + } + + func presentUnableToResumeSessionAlert() { + let alert = UIAlertController( + title: "Unable to Resume Session", + message: "There was an error while attempting to resume session.", + preferredStyle: .alert + ) + alert.addAction(UIAlertAction(title: "OK", style: .default, handler: nil)) + + self.present(alert, animated: true) + } +} + +// MARK: - CameraFeedManagerDelegate Methods +extension ViewController: CameraFeedManagerDelegate { + func cameraFeedManager(_ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer) { + runModel(on: pixelBuffer) + } + + // MARK: Session Handling Alerts + func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) { + // Handles session run time error by updating the UI and providing a button if session can be + // manually resumed. + self.resumeButton.isHidden = false + } + + func cameraFeedManager( + _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool + ) { + // Updates the UI when session is interupted. + if canResumeManually { + self.resumeButton.isHidden = false + } else { + self.cameraUnavailableLabel.isHidden = false + } + } + + func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) { + // Updates UI once session interruption has ended. 
+ self.cameraUnavailableLabel.isHidden = true + self.resumeButton.isHidden = true + } + + func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) { + let alertController = UIAlertController( + title: "Confirguration Failed", message: "Configuration of camera has failed.", + preferredStyle: .alert) + let okAction = UIAlertAction(title: "OK", style: .cancel, handler: nil) + alertController.addAction(okAction) + + present(alertController, animated: true, completion: nil) + } + + func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) { + let alertController = UIAlertController( + title: "Camera Permissions Denied", + message: + "Camera permissions have been denied for this app. You can change this by going to Settings", + preferredStyle: .alert) + + let cancelAction = UIAlertAction(title: "Cancel", style: .cancel, handler: nil) + let settingsAction = UIAlertAction(title: "Settings", style: .default) { action in + if let url = URL.init(string: UIApplication.openSettingsURLString) { + UIApplication.shared.open(url, options: [:], completionHandler: nil) + } + } + + alertController.addAction(cancelAction) + alertController.addAction(settingsAction) + + present(alertController, animated: true, completion: nil) + } + + @objc func runModel(on pixelBuffer: CVPixelBuffer) { + guard let overlayViewFrame = overlayViewFrame, let previewViewFrame = previewViewFrame + else { + return + } + // To put `overlayView` area as model input, transform `overlayViewFrame` following transform + // from `previewView` to `pixelBuffer`. `previewView` area is transformed to fit in + // `pixelBuffer`, because `pixelBuffer` as a camera output is resized to fill `previewView`. + // https://developer.apple.com/documentation/avfoundation/avlayervideogravity/1385607-resizeaspectfill + let modelInputRange = overlayViewFrame.applying( + previewViewFrame.size.transformKeepAspect(toFitIn: pixelBuffer.size)) + + // Run Midas model. + guard + let (result, width, height, times) = self.modelDataHandler?.runMidas( + on: pixelBuffer, + from: modelInputRange, + to: overlayViewFrame.size) + else { + os_log("Cannot get inference result.", type: .error) + return + } + + if avg_latency == 0 { + avg_latency = times.inference + } else { + avg_latency = times.inference*0.1 + avg_latency*0.9 + } + + // Udpate `inferencedData` to render data in `tableView`. + inferencedData = InferencedData(score: Float(avg_latency), times: times) + + //let height = 256 + //let width = 256 + + let outputs = result + let outputs_size = width * height; + + var multiplier : Float = 1.0; + + let max_val : Float = outputs.max() ?? 0 + let min_val : Float = outputs.min() ?? 0 + + if((max_val - min_val) > 0) { + multiplier = 255 / (max_val - min_val); + } + + // Draw result. 
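+    // Each depth value is mapped to a grey level with (value - min_val) * multiplier, i.e. a linear
+    // rescale of the model output to 0...255, and the resulting ARGB buffer is turned into a UIImage
+    // via the UIImage(pixels:width:height:) initializer defined at the top of this file.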
+ DispatchQueue.main.async { + self.tableView.reloadData() + + var pixels: [PixelData] = .init(repeating: .init(a: 255, r: 0, g: 0, b: 0), count: width * height) + + for i in pixels.indices { + //if(i < 1000) + //{ + let val = UInt8((outputs[i] - min_val) * multiplier) + + pixels[i].r = val + pixels[i].g = val + pixels[i].b = val + //} + } + + + /* + pixels[i].a = 255 + pixels[i].r = .random(in: 0...255) + pixels[i].g = .random(in: 0...255) + pixels[i].b = .random(in: 0...255) + } + */ + + DispatchQueue.main.async { + let image = UIImage(pixels: pixels, width: width, height: height) + + self.imageView.image = image + + if (self.imageViewInitialized == false) { + self.imageViewInitialized = true + self.overlayView.addSubview(self.imageView) + self.overlayView.setNeedsDisplay() + } + } + + /* + let image = UIImage(pixels: pixels, width: width, height: height) + + var imageView : UIImageView + imageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400)); + imageView.image = image + self.overlayView.addSubview(imageView) + self.overlayView.setNeedsDisplay() + */ + } + } +/* + func drawResult(of result: Result) { + self.overlayView.dots = result.dots + self.overlayView.lines = result.lines + self.overlayView.setNeedsDisplay() + } + + func clearResult() { + self.overlayView.clear() + self.overlayView.setNeedsDisplay() + } + */ + +} + + +// MARK: - TableViewDelegate, TableViewDataSource Methods +extension ViewController: UITableViewDelegate, UITableViewDataSource { + func numberOfSections(in tableView: UITableView) -> Int { + return InferenceSections.allCases.count + } + + func tableView(_ tableView: UITableView, numberOfRowsInSection section: Int) -> Int { + guard let section = InferenceSections(rawValue: section) else { + return 0 + } + + return section.subcaseCount + } + + func tableView(_ tableView: UITableView, cellForRowAt indexPath: IndexPath) -> UITableViewCell { + let cell = tableView.dequeueReusableCell(withIdentifier: "InfoCell") as! 
InfoCell + guard let section = InferenceSections(rawValue: indexPath.section) else { + return cell + } + guard let data = inferencedData else { return cell } + + var fieldName: String + var info: String + + switch section { + case .Score: + fieldName = section.description + info = String(format: "%.3f", data.score) + case .Time: + guard let row = ProcessingTimes(rawValue: indexPath.row) else { + return cell + } + var time: Double + switch row { + case .InferenceTime: + time = data.times.inference + } + fieldName = row.description + info = String(format: "%.2fms", time) + } + + cell.fieldNameLabel.text = fieldName + cell.infoLabel.text = info + + return cell + } + + func tableView(_ tableView: UITableView, heightForRowAt indexPath: IndexPath) -> CGFloat { + guard let section = InferenceSections(rawValue: indexPath.section) else { + return 0 + } + + var height = Traits.normalCellHeight + if indexPath.row == section.subcaseCount - 1 { + height = Traits.separatorCellHeight + Traits.bottomSpacing + } + return height + } + +} + +// MARK: - Private enums +/// UI coinstraint values +fileprivate enum Traits { + static let normalCellHeight: CGFloat = 35.0 + static let separatorCellHeight: CGFloat = 25.0 + static let bottomSpacing: CGFloat = 30.0 +} + +fileprivate struct InferencedData { + var score: Float + var times: Times +} + +/// Type of sections in Info Cell +fileprivate enum InferenceSections: Int, CaseIterable { + case Score + case Time + + var description: String { + switch self { + case .Score: + return "Average" + case .Time: + return "Processing Time" + } + } + + var subcaseCount: Int { + switch self { + case .Score: + return 1 + case .Time: + return ProcessingTimes.allCases.count + } + } +} + +/// Type of processing times in Time section in Info Cell +fileprivate enum ProcessingTimes: Int, CaseIterable { + case InferenceTime + + var description: String { + switch self { + case .InferenceTime: + return "Inference Time" + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift new file mode 100644 index 0000000000000000000000000000000000000000..3b53910b57563b6a195fd53321fa2a24ebaf3d3f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift @@ -0,0 +1,63 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +/// UIView for rendering inference output. 
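+/// `OverlayView` keeps plain lists of points (`dots`) and segments (`lines`) and redraws them in
+/// `draw(_:)` using `UIBezierPath`; call `clear()` to empty both lists before the next frame.
+/// The depth example in `ViewController` currently displays its result through a `UIImageView`
+/// instead, so this view mainly serves point/line style overlays.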
+class OverlayView: UIView { + + var dots = [CGPoint]() + var lines = [Line]() + + override func draw(_ rect: CGRect) { + for dot in dots { + drawDot(of: dot) + } + for line in lines { + drawLine(of: line) + } + } + + func drawDot(of dot: CGPoint) { + let dotRect = CGRect( + x: dot.x - Traits.dot.radius / 2, y: dot.y - Traits.dot.radius / 2, + width: Traits.dot.radius, height: Traits.dot.radius) + let dotPath = UIBezierPath(ovalIn: dotRect) + + Traits.dot.color.setFill() + dotPath.fill() + } + + func drawLine(of line: Line) { + let linePath = UIBezierPath() + linePath.move(to: CGPoint(x: line.from.x, y: line.from.y)) + linePath.addLine(to: CGPoint(x: line.to.x, y: line.to.y)) + linePath.close() + + linePath.lineWidth = Traits.line.width + Traits.line.color.setStroke() + + linePath.stroke() + } + + func clear() { + self.dots = [] + self.lines = [] + } +} + +private enum Traits { + static let dot = (radius: CGFloat(5), color: UIColor.orange) + static let line = (width: CGFloat(1.0), color: UIColor.orange) +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..5e9461fc96dbbe3c22ca6bbf2bfd7df3981b9462 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile @@ -0,0 +1,12 @@ +# Uncomment the next line to define a global platform for your project + platform :ios, '12.0' + +target 'Midas' do + # Comment the next line if you're not using Swift and don't want to use dynamic frameworks + use_frameworks! + + # Pods for Midas + pod 'TensorFlowLiteSwift', '~> 0.0.1-nightly' + pod 'TensorFlowLiteSwift/CoreML', '~> 0.0.1-nightly' + pod 'TensorFlowLiteSwift/Metal', '~> 0.0.1-nightly' +end diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7b8eb29feaa21e67814b035dbd5c5fb2c62a4151 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md @@ -0,0 +1,105 @@ +# Tensorflow Lite MiDaS iOS Example + +### Requirements + +- XCode 11.0 or above +- iOS 12.0 or above, [iOS 14 breaks the NPU Delegate](https://github.com/tensorflow/tensorflow/issues/43339) +- TensorFlow 2.4.0, TensorFlowLiteSwift -> 0.0.1-nightly + +## Quick Start with a MiDaS Example + +MiDaS is a neural network to compute depth from a single image. It uses TensorFlowLiteSwift / C++ libraries on iOS. The code is written in Swift. 
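+
+Under the hood the example drives the model through the TensorFlowLiteSwift `Interpreter`. Roughly,
+the flow looks like the sketch below; the resource path and the pre-processed `inputData` are
+illustrative placeholders, and the actual pre/post-processing in this example lives in
+`ModelDataHandler`:
+
+```swift
+import TensorFlowLite
+
+/// Runs one inference pass and returns the raw depth values.
+/// `inputData` is assumed to already match the model's input shape and type.
+func runDepthModel(modelPath: String, inputData: Data) throws -> [Float32] {
+    let interpreter = try Interpreter(modelPath: modelPath)
+    try interpreter.allocateTensors()
+    try interpreter.copy(inputData, toInputAt: 0)   // feed the pre-processed camera frame
+    try interpreter.invoke()                        // run the MiDaS network
+    let output = try interpreter.output(at: 0)      // inverse relative depth map (Float32)
+    return output.data.withUnsafeBytes { Array($0.bindMemory(to: Float32.self)) }
+}
+```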
+
+Paper: https://arxiv.org/abs/1907.01341
+
+> Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
+> René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun
+
+### Install TensorFlow
+
+Set the default python version to python3:
+
+```
+echo 'export PATH=/usr/local/opt/python/libexec/bin:$PATH' >> ~/.zshenv
+echo 'alias python=python3' >> ~/.zshenv
+echo 'alias pip=pip3' >> ~/.zshenv
+```
+
+Install TensorFlow:
+
+```shell
+pip install tensorflow
+```
+
+### Install TensorFlowLiteSwift via Cocoapods
+
+Set the required TensorFlowLiteSwift version in the Podfile (`0.0.1-nightly` is recommended): https://github.com/isl-org/MiDaS/blob/master/mobile/ios/Podfile#L9
+
+Install brew, ruby and cocoapods:
+
+```
+ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
+brew install mc rbenv ruby-build
+sudo gem install cocoapods
+```
+
+
+The TensorFlowLiteSwift library is available on [Cocoapods](https://cocoapods.org/). To integrate it into the project, run the following in the root directory of the project:
+
+```shell
+pod install
+```
+
+Now open the `Midas.xcworkspace` file in XCode, select your iPhone device (XCode->Product->Destination->iPhone) and launch it (Cmd + R). If everything works well, you should see a real-time depth map from your camera.
+
+### Model
+
+The TensorFlow Lite (TFLite) model `midas.tflite` is in the folder `/Midas/Model`.
+
+
+To use another model, convert it from a TensorFlow saved model to a TFLite model so that it can be deployed:
+
+```python
+import tensorflow as tf
+
+saved_model_export_dir = "./saved_model"
+model_tflite_name = "model.tflite"
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_export_dir)
+tflite_model = converter.convert()
+open(model_tflite_name, "wb").write(tflite_model)
+```
+
+### Setup XCode
+
+* Open the `.xcworkspace` file in XCode
+
+* Click your project name (top-left corner) -> change the Bundle Identifier to `com.midas.tflite-npu` or something similar (it must be unique)
+
+* Select your Developer Team (you should be signed in with your Apple ID)
+
+* Connect your iPhone (if you want to run on a real device instead of the simulator) and select it as the destination (XCode->Product->Destination->iPhone)
+
+* In XCode, click: Product -> Run
+
+* On your iPhone, go to: Settings -> General -> Device Management (or Profiles) -> Apple Development -> Trust Apple Development
+
+----
+
+Original repository: https://github.com/isl-org/MiDaS
+
+
+### Examples:
+
+| ![photo_2020-09-27_17-43-20](https://user-images.githubusercontent.com/4096485/94367804-9610de80-00e9-11eb-8a23-8b32a6f52d41.jpg) | ![photo_2020-09-27_17-49-22](https://user-images.githubusercontent.com/4096485/94367974-7201cd00-00ea-11eb-8e0a-68eb9ea10f63.jpg) | ![photo_2020-09-27_17-52-30](https://user-images.githubusercontent.com/4096485/94367976-729a6380-00ea-11eb-8ce0-39d3e26dd550.jpg) | ![photo_2020-09-27_17-43-21](https://user-images.githubusercontent.com/4096485/94367807-97420b80-00e9-11eb-9dcd-848ad9e89e03.jpg) |
+|---|---|---|---|
+
+## LICENSE
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..d737b39d966278f5c6bc29802526ab86f8473de4 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Download TF Lite model from the internet if it does not exist. + +TFLITE_MODEL="model_opt.tflite" +TFLITE_FILE="Midas/Model/${TFLITE_MODEL}" +MODEL_SRC="https://github.com/isl-org/MiDaS/releases/download/v2/${TFLITE_MODEL}" + +if test -f "${TFLITE_FILE}"; then + echo "INFO: TF Lite model already exists. Skip downloading and use the local model." +else + curl --create-dirs -o "${TFLITE_FILE}" -LJO "${MODEL_SRC}" + echo "INFO: Downloaded TensorFlow Lite model to ${TFLITE_FILE}." +fi + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1d43c2606767798ee46b34292e0483197424ec23 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md @@ -0,0 +1,131 @@ +# MiDaS for ROS1 by using LibTorch in C++ + +### Requirements + +- Ubuntu 17.10 / 18.04 / 20.04, Debian Stretch +- ROS Melodic for Ubuntu (17.10 / 18.04) / Debian Stretch, ROS Noetic for Ubuntu 20.04 +- C++11 +- LibTorch >= 1.6 + +## Quick Start with a MiDaS Example + +MiDaS is a neural network to compute depth from a single image. + +* input from `image_topic`: `sensor_msgs/Image` - `RGB8` image with any shape +* output to `midas_topic`: `sensor_msgs/Image` - `TYPE_32FC1` inverse relative depth maps in range [0 - 255] with original size and channels=1 + +### Install Dependecies + +* install ROS Melodic for Ubuntu 17.10 / 18.04: +```bash +wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_melodic_ubuntu_17_18.sh +./install_ros_melodic_ubuntu_17_18.sh +``` + +or Noetic for Ubuntu 20.04: + +```bash +wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_noetic_ubuntu_20.sh +./install_ros_noetic_ubuntu_20.sh +``` + + +* install LibTorch 1.7 with CUDA 11.0: + +On **Jetson (ARM)**: +```bash +wget https://nvidia.box.com/shared/static/wa34qwrwtk9njtyarwt5nvo6imenfy26.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl +sudo apt-get install python3-pip libopenblas-base libopenmpi-dev +pip3 install Cython +pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl +``` +Or compile LibTorch from source: https://github.com/pytorch/pytorch#from-source + +On **Linux (x86_64)**: +```bash +cd ~/ +wget https://download.pytorch.org/libtorch/cu110/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu110.zip +unzip libtorch-cxx11-abi-shared-with-deps-1.7.0+cu110.zip +``` + +* create symlink for OpenCV: + +```bash +sudo ln -s /usr/include/opencv4 /usr/include/opencv +``` + +* download and install MiDaS: + +```bash +source ~/.bashrc +cd ~/ +mkdir catkin_ws +cd catkin_ws +git clone https://github.com/isl-org/MiDaS +mkdir src +cp -r MiDaS/ros/* src + +chmod +x src/additions/*.sh +chmod +x src/*.sh +chmod +x src/midas_cpp/scripts/*.py +cp src/additions/do_catkin_make.sh ./do_catkin_make.sh +./do_catkin_make.sh +./src/additions/downloads.sh +``` + +### Usage + +* run only `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` + +#### Test + +* Test - capture video and show result in the window: + * place any `test.mp4` video file to the directory `~/catkin_ws/src/` + * run `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` + * run test nodes in another terminal: `cd ~/catkin_ws/src && ./run_talker_listener_test.sh` and wait 30 seconds + + (to use Python 2, run command `sed -i 's/python3/python2/' ~/catkin_ws/src/midas_cpp/scripts/*.py` ) + +## Mobile version of MiDaS - Monocular Depth Estimation + +### Accuracy + +* MiDaS v2 small - ResNet50 default-decoder 384x384 +* MiDaS v2.1 small - EfficientNet-Lite3 small-decoder 256x256 + +**Zero-shot error** (the lower - the better): + +| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 | +|---|---|---|---|---|---|---| +| MiDaS v2 small 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | +| MiDaS v2.1 small 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | 
**14.53** | +| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** | + +None of Train/Valid/Test subsets of datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were not involved in Training or Fine Tuning. + +### Inference speed (FPS) on nVidia GPU + +Inference speed excluding pre and post processing, batch=1, **Frames Per Second** (the higher - the better): + +| Model | Jetson Nano, FPS | RTX 2080Ti, FPS | +|---|---|---| +| MiDaS v2 small 384x384 | 1.6 | 117 | +| MiDaS v2.1 small 256x256 | 8.1 | 232 | +| SpeedUp, X times | **5x** | **2x** | + +### Citation + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2020, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d416fc00282aab146326bbba12a9274e1ba29b8 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh @@ -0,0 +1,5 @@ +mkdir src +catkin_make +source devel/setup.bash +echo $ROS_PACKAGE_PATH +chmod +x ./devel/setup.bash diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh new file mode 100644 index 0000000000000000000000000000000000000000..9c967d4e2dc7997da26399a063b5a54ecc314eb1 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh @@ -0,0 +1,5 @@ +mkdir ~/.ros +wget https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small-traced.pt +cp ./model-small-traced.pt ~/.ros/model-small-traced.pt + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh new file mode 100644 index 0000000000000000000000000000000000000000..b868112631e9d9bc7bccb601407dfc857b8a99d5 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh @@ -0,0 +1,34 @@ +#@title { display-mode: "code" } + +#from http://wiki.ros.org/indigo/Installation/Ubuntu + +#1.2 Setup sources.list +sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' + +# 1.3 Setup keys +sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 +sudo apt-key adv --keyserver 'hkp://ha.pool.sks-keyservers.net:80' --recv-key 421C365BD9FF1F717815A3895523BAEEB01FA116 + +curl -sSL 
'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - + +# 1.4 Installation +sudo apt-get update +sudo apt-get upgrade + +# Desktop-Full Install: +sudo apt-get install ros-melodic-desktop-full + +printf "\nsource /opt/ros/melodic/setup.bash\n" >> ~/.bashrc + +# 1.5 Initialize rosdep +sudo rosdep init +rosdep update + + +# 1.7 Getting rosinstall (python) +sudo apt-get install python-rosinstall +sudo apt-get install python-catkin-tools +sudo apt-get install python-rospy +sudo apt-get install python-rosdep +sudo apt-get install python-roscd +sudo apt-get install python-pip \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh new file mode 100644 index 0000000000000000000000000000000000000000..d73ea1a3d92359819167d735a92d2a650b9bc245 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh @@ -0,0 +1,33 @@ +#@title { display-mode: "code" } + +#from http://wiki.ros.org/indigo/Installation/Ubuntu + +#1.2 Setup sources.list +sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' + +# 1.3 Setup keys +sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 + +curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - + +# 1.4 Installation +sudo apt-get update +sudo apt-get upgrade + +# Desktop-Full Install: +sudo apt-get install ros-noetic-desktop-full + +printf "\nsource /opt/ros/noetic/setup.bash\n" >> ~/.bashrc + +# 1.5 Initialize rosdep +sudo rosdep init +rosdep update + + +# 1.7 Getting rosinstall (python) +sudo apt-get install python3-rosinstall +sudo apt-get install python3-catkin-tools +sudo apt-get install python3-rospy +sudo apt-get install python3-rosdep +sudo apt-get install python3-roscd +sudo apt-get install python3-pip \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0ef6073a9c9ce40744e1c81d557c1c68255b95e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh @@ -0,0 +1,16 @@ +cd ~/catkin_ws/src +catkin_create_pkg midas_cpp std_msgs roscpp cv_bridge sensor_msgs image_transport +cd ~/catkin_ws +catkin_make + +chmod +x ~/catkin_ws/devel/setup.bash +printf "\nsource ~/catkin_ws/devel/setup.bash" >> ~/.bashrc +source ~/catkin_ws/devel/setup.bash + + +sudo rosdep init +rosdep update +#rospack depends1 midas_cpp +roscd midas_cpp +#cat package.xml +#rospack depends midas_cpp \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..5a0d1583fffdc49216c625dfd07af2ae3b01a7a0 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh @@ -0,0 +1,2 @@ +source ~/catkin_ws/devel/setup.bash 
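+# Launch the midas_cpp node; the model file, input/output topics and the original-size flag are
+# all passed as roslaunch arguments on the line below.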
+roslaunch midas_cpp midas_cpp.launch model_name:="model-small-traced.pt" input_topic:="image_topic" output_topic:="midas_topic" out_orig_size:="true" \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..885341691d217f9c4c8fcb1e4ff568d87788c7b8 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt @@ -0,0 +1,189 @@ +cmake_minimum_required(VERSION 3.0.2) +project(midas_cpp) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED COMPONENTS + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs +) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + +list(APPEND CMAKE_PREFIX_PATH "~/libtorch") +list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python3.6/dist-packages/torch/lib") +list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python2.7/dist-packages/torch/lib") + +if(NOT EXISTS "~/libtorch") + if (EXISTS "/usr/local/lib/python3.6/dist-packages/torch") + include_directories(/usr/local/include) + include_directories(/usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include) + include_directories(/usr/local/lib/python3.6/dist-packages/torch/include) + + link_directories(/usr/local/lib) + link_directories(/usr/local/lib/python3.6/dist-packages/torch/lib) + + set(CMAKE_PREFIX_PATH /usr/local/lib/python3.6/dist-packages/torch) + set(Boost_USE_MULTITHREADED ON) + set(Torch_DIR /usr/local/lib/python3.6/dist-packages/torch) + + elseif (EXISTS "/usr/local/lib/python2.7/dist-packages/torch") + + include_directories(/usr/local/include) + include_directories(/usr/local/lib/python2.7/dist-packages/torch/include/torch/csrc/api/include) + include_directories(/usr/local/lib/python2.7/dist-packages/torch/include) + + link_directories(/usr/local/lib) + link_directories(/usr/local/lib/python2.7/dist-packages/torch/lib) + + set(CMAKE_PREFIX_PATH /usr/local/lib/python2.7/dist-packages/torch) + set(Boost_USE_MULTITHREADED ON) + set(Torch_DIR /usr/local/lib/python2.7/dist-packages/torch) + endif() +endif() + + + +find_package(Torch REQUIRED) +find_package(OpenCV REQUIRED) +include_directories( ${OpenCV_INCLUDE_DIRS} ) + +add_executable(midas_cpp src/main.cpp) +target_link_libraries(midas_cpp "${TORCH_LIBRARIES}" "${OpenCV_LIBS} ${catkin_LIBRARIES}") +set_property(TARGET midas_cpp PROPERTY CXX_STANDARD 14) + + + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES midas_cpp +# CATKIN_DEPENDS cv_bridge image_transport roscpp sensor_msgs std_msgs +# DEPENDS system_lib +) + 
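+## If LibTorch is unpacked somewhere other than the locations probed above (~/libtorch or the
+## torch dist-packages directories), its location can instead be passed at configure time, e.g.:
+##   cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
+## (illustrative path; typically any directory containing share/cmake/Torch works)
+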
+########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include + ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/midas_cpp.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/midas_cpp_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user use +## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +# catkin_install_python(PROGRAMS +# scripts/my_python_script +# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark executables for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html +# install(TARGETS ${PROJECT_NAME}_node +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark libraries for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html +# install(TARGETS ${PROJECT_NAME} +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) 
+# install(FILES +# # myfile1 +# # myfile2 +# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +# ) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_midas_cpp.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) + +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +add_custom_command( + TARGET midas_cpp POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_CURRENT_BINARY_DIR}/midas_cpp + ${CMAKE_SOURCE_DIR}/midas_cpp +) \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch new file mode 100644 index 0000000000000000000000000000000000000000..88e86f42f668e76ad4976ec6794a8cb0f20cac65 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch new file mode 100644 index 0000000000000000000000000000000000000000..8817a4f4933c56986fe0edc0886b2fded3d3406d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml new file mode 100644 index 0000000000000000000000000000000000000000..9cac90eba75409bd170f73531c54c83c52ff047a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml @@ -0,0 +1,77 @@ + + + midas_cpp + 0.1.0 + The midas_cpp package + + Alexey Bochkovskiy + MIT + https://github.com/isl-org/MiDaS/tree/master/ros + + + + + + + TODO + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + catkin + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py new file mode 100644 index 0000000000000000000000000000000000000000..6927ea7a83ac9309e5f883ee974a5dcfa8a2aa3b --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +from __future__ import print_function + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +import numpy as np +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import 
CvBridge, CvBridgeError + +class video_show: + + def __init__(self): + self.show_output = rospy.get_param('~show_output', True) + self.save_output = rospy.get_param('~save_output', False) + self.output_video_file = rospy.get_param('~output_video_file','result.mp4') + # rospy.loginfo(f"Listener - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") + + self.bridge = CvBridge() + self.image_sub = rospy.Subscriber("midas_topic", Image, self.callback) + + def callback(self, data): + try: + cv_image = self.bridge.imgmsg_to_cv2(data) + except CvBridgeError as e: + print(e) + return + + if cv_image.size == 0: + return + + rospy.loginfo("Listener: Received new frame") + cv_image = cv_image.astype("uint8") + + if self.show_output==True: + cv2.imshow("video_show", cv_image) + cv2.waitKey(10) + + if self.save_output==True: + if self.video_writer_init==False: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0])) + + self.out.write(cv_image) + + + +def main(args): + rospy.init_node('listener', anonymous=True) + ic = video_show() + try: + rospy.spin() + except KeyboardInterrupt: + print("Shutting down") + cv2.destroyAllWindows() + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py new file mode 100644 index 0000000000000000000000000000000000000000..20e235f6958d644b89383752ab18e9e2275f55e5 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +from __future__ import print_function + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +import numpy as np +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + +class video_show: + + def __init__(self): + self.show_output = rospy.get_param('~show_output', True) + self.save_output = rospy.get_param('~save_output', False) + self.output_video_file = rospy.get_param('~output_video_file','result.mp4') + # rospy.loginfo(f"Listener original - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") + + self.bridge = CvBridge() + self.image_sub = rospy.Subscriber("image_topic", Image, self.callback) + + def callback(self, data): + try: + cv_image = self.bridge.imgmsg_to_cv2(data) + except CvBridgeError as e: + print(e) + return + + if cv_image.size == 0: + return + + rospy.loginfo("Listener_original: Received new frame") + cv_image = cv_image.astype("uint8") + + if self.show_output==True: + cv2.imshow("video_show_orig", cv_image) + cv2.waitKey(10) + + if self.save_output==True: + if self.video_writer_init==False: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0])) + + self.out.write(cv_image) + + + +def main(args): + rospy.init_node('listener_original', anonymous=True) + ic = video_show() + try: + rospy.spin() + except KeyboardInterrupt: + print("Shutting down") + cv2.destroyAllWindows() + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git 
a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py new file mode 100644 index 0000000000000000000000000000000000000000..8219cc8632484a2efd02984347c615efad6b78b2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + + +def talker(): + rospy.init_node('talker', anonymous=True) + + use_camera = rospy.get_param('~use_camera', False) + input_video_file = rospy.get_param('~input_video_file','test.mp4') + # rospy.loginfo(f"Talker - params: use_camera={use_camera}, input_video_file={input_video_file}") + + # rospy.loginfo("Talker: Trying to open a video stream") + if use_camera == True: + cap = cv2.VideoCapture(0) + else: + cap = cv2.VideoCapture(input_video_file) + + pub = rospy.Publisher('image_topic', Image, queue_size=1) + rate = rospy.Rate(30) # 30hz + bridge = CvBridge() + + while not rospy.is_shutdown(): + ret, cv_image = cap.read() + if ret==False: + print("Talker: Video is over") + rospy.loginfo("Video is over") + return + + try: + image = bridge.cv2_to_imgmsg(cv_image, "bgr8") + except CvBridgeError as e: + rospy.logerr("Talker: cv2image conversion failed: ", e) + print(e) + continue + + rospy.loginfo("Talker: Publishing frame") + pub.publish(image) + rate.sleep() + +if __name__ == '__main__': + try: + talker() + except rospy.ROSInterruptException: + pass diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4fc72c6955f66af71c9cb1fc7a7b1f643129685 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp @@ -0,0 +1,285 @@ +#include +#include +#include +#include + +#include + +#include // One-stop header. + +#include +#include +#include +#include + +#include +#include + +// includes for OpenCV >= 3.x +#ifndef CV_VERSION_EPOCH +#include +#include +#include +#endif + +// OpenCV includes for OpenCV 2.x +#ifdef CV_VERSION_EPOCH +#include +#include +#include +#include +#endif + +static const std::string OPENCV_WINDOW = "Image window"; + +class Midas +{ + ros::NodeHandle nh_; + image_transport::ImageTransport it_; + image_transport::Subscriber image_sub_; + image_transport::Publisher image_pub_; + + torch::jit::script::Module module; + torch::Device device; + + auto ToTensor(cv::Mat img, bool show_output = false, bool unsqueeze = false, int unsqueeze_dim = 0) + { + //std::cout << "image shape: " << img.size() << std::endl; + at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kByte); + + if (unsqueeze) + { + tensor_image.unsqueeze_(unsqueeze_dim); + //std::cout << "tensors new shape: " << tensor_image.sizes() << std::endl; + } + + if (show_output) + { + std::cout << tensor_image.slice(2, 0, 1) << std::endl; + } + //std::cout << "tenor shape: " << tensor_image.sizes() << std::endl; + return tensor_image; + } + + auto ToInput(at::Tensor tensor_image) + { + // Create a vector of inputs. 
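+        // (torch::jit::script::Module::forward() expects its arguments as a std::vector of
+        // torch::jit::IValue; the at::Tensor is wrapped into an IValue when placed in the vector.)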
+ return std::vector{tensor_image}; + } + + auto ToCvImage(at::Tensor tensor, int cv_type = CV_8UC3) + { + int width = tensor.sizes()[0]; + int height = tensor.sizes()[1]; + try + { + cv::Mat output_mat; + if (cv_type == CV_8UC4 || cv_type == CV_8UC3 || cv_type == CV_8UC2 || cv_type == CV_8UC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + else if (cv_type == CV_32FC4 || cv_type == CV_32FC3 || cv_type == CV_32FC2 || cv_type == CV_32FC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + else if (cv_type == CV_64FC4 || cv_type == CV_64FC3 || cv_type == CV_64FC2 || cv_type == CV_64FC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + + //show_image(output_mat, "converted image from tensor"); + return output_mat.clone(); + } + catch (const c10::Error& e) + { + std::cout << "an error has occured : " << e.msg() << std::endl; + } + return cv::Mat(height, width, CV_8UC3); + } + + std::string input_topic, output_topic, model_name; + bool out_orig_size; + int net_width, net_height; + torch::NoGradGuard guard; + at::Tensor mean, std; + at::Tensor output, tensor; + +public: + Midas() + : nh_(), it_(nh_), device(torch::Device(torch::kCPU)) + { + ros::param::param("~input_topic", input_topic, "image_topic"); + ros::param::param("~output_topic", output_topic, "midas_topic"); + ros::param::param("~model_name", model_name, "model-small-traced.pt"); + ros::param::param("~out_orig_size", out_orig_size, true); + ros::param::param("~net_width", net_width, 256); + ros::param::param("~net_height", net_height, 256); + + std::cout << ", input_topic = " << input_topic << + ", output_topic = " << output_topic << + ", model_name = " << model_name << + ", out_orig_size = " << out_orig_size << + ", net_width = " << net_width << + ", net_height = " << net_height << + std::endl; + + // Subscrive to input video feed and publish output video feed + image_sub_ = it_.subscribe(input_topic, 1, &Midas::imageCb, this); + image_pub_ = it_.advertise(output_topic, 1); + + std::cout << "Try to load torchscript model \n"; + + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). 
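+            // model_name defaults to "model-small-traced.pt"; additions/downloads.sh copies it into
+            // ~/.ros, which roslaunch uses as the working directory for nodes by default, so the
+            // bare file name resolves there. torch::jit::load() throws c10::Error on failure,
+            // which is caught just below.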
+ module = torch::jit::load(model_name); + } + catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + exit(0); + } + + std::cout << "ok\n"; + + try { + module.eval(); + torch::jit::getProfilingMode() = false; + torch::jit::setGraphExecutorOptimize(true); + + mean = torch::tensor({ 0.485, 0.456, 0.406 }); + std = torch::tensor({ 0.229, 0.224, 0.225 }); + + if (torch::hasCUDA()) { + std::cout << "cuda is available" << std::endl; + at::globalContext().setBenchmarkCuDNN(true); + device = torch::Device(torch::kCUDA); + module.to(device); + mean = mean.to(device); + std = std.to(device); + } + } + catch (const c10::Error& e) + { + std::cerr << " module initialization: " << e.msg() << std::endl; + } + } + + ~Midas() + { + } + + void imageCb(const sensor_msgs::ImageConstPtr& msg) + { + cv_bridge::CvImagePtr cv_ptr; + try + { + // sensor_msgs::Image to cv::Mat + cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::RGB8); + } + catch (cv_bridge::Exception& e) + { + ROS_ERROR("cv_bridge exception: %s", e.what()); + return; + } + + // pre-processing + auto tensor_cpu = ToTensor(cv_ptr->image); // OpenCV-image -> Libtorch-tensor + + try { + tensor = tensor_cpu.to(device); // move to device (CPU or GPU) + + tensor = tensor.toType(c10::kFloat); + tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW + tensor = tensor.unsqueeze(0); + tensor = at::upsample_bilinear2d(tensor, { net_height, net_width }, true); // resize + tensor = tensor.squeeze(0); + tensor = tensor.permute({ 1, 2, 0 }); // CHW -> HWC + + tensor = tensor.div(255).sub(mean).div(std); // normalization + tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW + tensor.unsqueeze_(0); // CHW -> NCHW + } + catch (const c10::Error& e) + { + std::cerr << " pre-processing exception: " << e.msg() << std::endl; + return; + } + + auto input_to_net = ToInput(tensor); // input to the network + + // inference + output; + try { + output = module.forward(input_to_net).toTensor(); // run inference + } + catch (const c10::Error& e) + { + std::cerr << " module.forward() exception: " << e.msg() << std::endl; + return; + } + + output = output.detach().to(torch::kF32); + + // move to CPU temporary + at::Tensor output_tmp = output; + output_tmp = output_tmp.to(torch::kCPU); + + // normalization + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::min(); + + for (int i = 0; i < net_width * net_height; ++i) { + float val = output_tmp.data_ptr()[i]; + if (min_val > val) min_val = val; + if (max_val < val) max_val = val; + } + float range_val = max_val - min_val; + + output = output.sub(min_val).div(range_val).mul(255.0F).clamp(0, 255).to(torch::kF32); // .to(torch::kU8); + + // resize to the original size if required + if (out_orig_size) { + try { + output = at::upsample_bilinear2d(output.unsqueeze(0), { cv_ptr->image.size().height, cv_ptr->image.size().width }, true); + output = output.squeeze(0); + } + catch (const c10::Error& e) + { + std::cout << " upsample_bilinear2d() exception: " << e.msg() << std::endl; + return; + } + } + output = output.permute({ 1, 2, 0 }).to(torch::kCPU); + + int cv_type = CV_32FC1; // CV_8UC1; + auto cv_img = ToCvImage(output, cv_type); + + sensor_msgs::Image img_msg; + + try { + // cv::Mat -> sensor_msgs::Image + std_msgs::Header header; // empty header + header.seq = 0; // user defined counter + header.stamp = ros::Time::now();// time + //cv_bridge::CvImage img_bridge = cv_bridge::CvImage(header, sensor_msgs::image_encodings::MONO8, cv_img); + cv_bridge::CvImage img_bridge = 
cv_bridge::CvImage(header, sensor_msgs::image_encodings::TYPE_32FC1, cv_img); + + img_bridge.toImageMsg(img_msg); // cv_bridge -> sensor_msgs::Image + } + catch (cv_bridge::Exception& e) + { + ROS_ERROR("cv_bridge exception: %s", e.what()); + return; + } + + // Output modified video stream + image_pub_.publish(img_msg); + } +}; + +int main(int argc, char** argv) +{ + ros::init(argc, argv, "midas", ros::init_options::AnonymousName); + Midas ic; + ros::spin(); + return 0; +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..a997c4261072d0d627598fe06a723fcc7522d347 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh @@ -0,0 +1,16 @@ +# place any test.mp4 file near with this file + +# roscore +# rosnode kill -a + +source ~/catkin_ws/devel/setup.bash + +roscore & +P1=$! +rosrun midas_cpp talker.py & +P2=$! +rosrun midas_cpp listener_original.py & +P3=$! +rosrun midas_cpp listener.py & +P4=$! +wait $P1 $P2 $P3 $P4 \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py new file mode 100644 index 0000000000000000000000000000000000000000..5696ef0547af093713ea416d18edd77d11879d0a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py @@ -0,0 +1,277 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import torch +import utils +import cv2 +import argparse +import time + +import numpy as np + +from imutils.video import VideoStream +from midas.model_loader import default_models, load_model + +first_execution = True +def process(device, model, model_type, image, input_size, target_size, optimize, use_camera): + """ + Run the inference and interpolate. + + Args: + device (torch.device): the torch device used + model: the model used for inference + model_type: the type of the model + image: the image fed into the neural network + input_size: the size (width, height) of the neural network input (for OpenVINO) + target_size: the size (width, height) the neural network output is interpolated to + optimize: optimize the model to half-floats on CUDA? + use_camera: is the camera used? + + Returns: + the prediction + """ + global first_execution + + if "openvino" in model_type: + if first_execution or not use_camera: + print(f" Input resized to {input_size[0]}x{input_size[1]} before entering the encoder") + first_execution = False + + sample = [np.reshape(image, (1, 3, *input_size))] + prediction = model(sample)[model.output(0)][0] + prediction = cv2.resize(prediction, dsize=target_size, + interpolation=cv2.INTER_CUBIC) + else: + sample = torch.from_numpy(image).to(device).unsqueeze(0) + + if optimize and device == torch.device("cuda"): + if first_execution: + print(" Optimization to half-floats activated. 
Use with caution, because models like Swin require\n" + " float precision to work properly and may yield non-finite depth values to some extent for\n" + " half-floats.") + sample = sample.to(memory_format=torch.channels_last) + sample = sample.half() + + if first_execution or not use_camera: + height, width = sample.shape[2:] + print(f" Input resized to {width}x{height} before entering the encoder") + first_execution = False + + prediction = model.forward(sample) + prediction = ( + torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=target_size[::-1], + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + return prediction + + +def create_side_by_side(image, depth, grayscale): + """ + Take an RGB image and depth map and place them side by side. This includes a proper normalization of the depth map + for better visibility. + + Args: + image: the RGB image + depth: the depth map + grayscale: use a grayscale colormap? + + Returns: + the image and depth map place side by side + """ + depth_min = depth.min() + depth_max = depth.max() + normalized_depth = 255 * (depth - depth_min) / (depth_max - depth_min) + normalized_depth *= 3 + + right_side = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2) / 3 + if not grayscale: + right_side = cv2.applyColorMap(np.uint8(right_side), cv2.COLORMAP_INFERNO) + + if image is None: + return right_side + else: + return np.concatenate((image, right_side), axis=1) + + +def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", optimize=False, side=False, height=None, + square=False, grayscale=False): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + model_type (str): the model type + optimize (bool): optimize the model to half-floats on CUDA? + side (bool): RGB and depth side by side in output images? + height (int): inference encoder image height + square (bool): resize to a square resolution? + grayscale (bool): use a grayscale colormap? + """ + print("Initialize") + + # select device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("Device: %s" % device) + + model, transform, net_w, net_h = load_model(device, model_path, model_type, optimize, height, square) + + # get input + if input_path is not None: + image_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(image_names) + else: + print("No input path specified. Grabbing images from camera.") + + # create output folder + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + + print("Start processing") + + if input_path is not None: + if output_path is None: + print("Warning: No output path specified. 
Images will be processed but not shown or stored anywhere.") + for index, image_name in enumerate(image_names): + + print(" Processing {} ({}/{})".format(image_name, index + 1, num_images)) + + # input + original_image_rgb = utils.read_image(image_name) # in [0, 1] + image = transform({"image": original_image_rgb})["image"] + + # compute + with torch.no_grad(): + prediction = process(device, model, model_type, image, (net_w, net_h), original_image_rgb.shape[1::-1], + optimize, False) + + # output + if output_path is not None: + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(image_name))[0] + '-' + model_type + ) + if not side: + utils.write_depth(filename, prediction, grayscale, bits=2) + else: + original_image_bgr = np.flip(original_image_rgb, 2) + content = create_side_by_side(original_image_bgr*255, prediction, grayscale) + cv2.imwrite(filename + ".png", content) + utils.write_pfm(filename + ".pfm", prediction.astype(np.float32)) + + else: + with torch.no_grad(): + fps = 1 + video = VideoStream(0).start() + time_start = time.time() + frame_index = 0 + while True: + frame = video.read() + if frame is not None: + original_image_rgb = np.flip(frame, 2) # in [0, 255] (flip required to get RGB) + image = transform({"image": original_image_rgb/255})["image"] + + prediction = process(device, model, model_type, image, (net_w, net_h), + original_image_rgb.shape[1::-1], optimize, True) + + original_image_bgr = np.flip(original_image_rgb, 2) if side else None + content = create_side_by_side(original_image_bgr, prediction, grayscale) + cv2.imshow('MiDaS Depth Estimation - Press Escape to close window ', content/255) + + if output_path is not None: + filename = os.path.join(output_path, 'Camera' + '-' + model_type + '_' + str(frame_index)) + cv2.imwrite(filename + ".png", content) + + alpha = 0.1 + if time.time()-time_start > 0: + fps = (1 - alpha) * fps + alpha * 1 / (time.time()-time_start) # exponential moving average + time_start = time.time() + print(f"\rFPS: {round(fps,2)}", end="") + + if cv2.waitKey(1) == 27: # Escape key + break + + frame_index += 1 + print() + + print("Finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default=None, + help='Folder with input images (if no input path is specified, images are tried to be grabbed ' + 'from camera)' + ) + + parser.add_argument('-o', '--output_path', + default=None, + help='Folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default=None, + help='Path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='dpt_beit_large_512', + help='Model type: ' + 'dpt_beit_large_512, dpt_beit_large_384, dpt_beit_base_384, dpt_swin2_large_384, ' + 'dpt_swin2_base_384, dpt_swin2_tiny_256, dpt_swin_large_384, dpt_next_vit_large_384, ' + 'dpt_levit_224, dpt_large_384, dpt_hybrid_384, midas_v21_384, midas_v21_small_256 or ' + 'openvino_midas_v21_small_256' + ) + + parser.add_argument('-s', '--side', + action='store_true', + help='Output images contain RGB and depth images side by side' + ) + + parser.add_argument('--optimize', dest='optimize', action='store_true', help='Use half-float optimization') + parser.set_defaults(optimize=False) + + parser.add_argument('--height', + type=int, default=None, + help='Preferred height of images feed into the encoder during inference. 
Note that the ' + 'preferred height may differ from the actual height, because an alignment to multiples of ' + '32 takes place. Many models support only the height chosen during training, which is ' + 'used automatically if this parameter is not set.' + ) + parser.add_argument('--square', + action='store_true', + help='Option to resize images to a square resolution by changing their widths when images are ' + 'fed into the encoder during inference. If this parameter is not set, the aspect ratio of ' + 'images is tried to be preserved if supported by the model.' + ) + parser.add_argument('--grayscale', + action='store_true', + help='Use a grayscale colormap instead of the inferno one. Although the inferno colormap, ' + 'which is used by default, is better for visibility, it does not allow storing 16-bit ' + 'depth values in PNGs but only 8-bit ones due to the precision limitation of this ' + 'colormap.' + ) + + args = parser.parse_args() + + + if args.model_weights is None: + args.model_weights = default_models[args.model_type] + + # set torch options + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type, args.optimize, args.side, args.height, + args.square, args.grayscale) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5b5fe0e63668eab45a55b140826cb3762862b17c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md @@ -0,0 +1,147 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +### TensorFlow inference using `.pb` and `.onnx` models + +1. [Run inference on TensorFlow-model by using TensorFlow](#run-inference-on-tensorflow-model-by-using-tensorFlow) + +2. [Run inference on ONNX-model by using TensorFlow](#run-inference-on-onnx-model-by-using-tensorflow) + +3. [Make ONNX model from downloaded Pytorch model file](#make-onnx-model-from-downloaded-pytorch-model-file) + + +### Run inference on TensorFlow-model by using TensorFlow + +1) Download the model weights [model-f6b98070.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pb) +and [model-small.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.pb) and place the +file in the `/tf/` folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install TensorFlow +pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 +``` + +#### Usage + +1) Place one or more input images in the folder `tf/input`. + +2) Run the model: + + ```shell + python tf/run_pb.py + ``` + + Or run the small model: + + ```shell + python tf/run_pb.py --model_weights model-small.pb --model_type small + ``` + +3) The resulting inverse depth maps are written to the `tf/output` folder. + + +### Run inference on ONNX-model by using ONNX-Runtime + +1) Download the model weights [model-f6b98070.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.onnx) +and [model-small.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.onnx) and place the +file in the `/tf/` folder. 
+ +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install ONNX +pip install onnx==1.7.0 + +# install ONNX Runtime +pip install onnxruntime==1.5.2 +``` + +#### Usage + +1) Place one or more input images in the folder `tf/input`. + +2) Run the model: + + ```shell + python tf/run_onnx.py + ``` + + Or run the small model: + + ```shell + python tf/run_onnx.py --model_weights model-small.onnx --model_type small + ``` + +3) The resulting inverse depth maps are written to the `tf/output` folder. + + + +### Make ONNX model from downloaded Pytorch model file + +1) Download the model weights [model-f6b98070.pt](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pt) and place the +file in the root folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install PyTorch TorchVision +pip install -I torch==1.7.0 torchvision==0.8.0 + +# install TensorFlow +pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 + +# install ONNX +pip install onnx==1.7.0 + +# install ONNX-TensorFlow +git clone https://github.com/onnx/onnx-tensorflow.git +cd onnx-tensorflow +git checkout 095b51b88e35c4001d70f15f80f31014b592b81e +pip install -e . +``` + +#### Usage + +1) Run the converter: + + ```shell + python tf/make_onnx_model.py + ``` + +2) The resulting `model-f6b98070.onnx` file is written to the `/tf/` folder. + + +### Requirements + + The code was tested with Python 3.6.9, PyTorch 1.5.1, TensorFlow 2.2.0, TensorFlow-addons 0.8.3, ONNX 1.7.0, ONNX-TensorFlow (GitHub-master-17.07.2020) and OpenCV 4.3.0. + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2019, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` + +### License + +MIT License + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d14b0e4e1d2ea70fa315fd7ca7dfd72440a19376 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py @@ -0,0 +1,112 @@ +"""Compute depth maps for images in the input folder. 
+""" +import os +import ntpath +import glob +import torch +import utils +import cv2 +import numpy as np +from torchvision.transforms import Compose, Normalize +from torchvision import transforms + +from shutil import copyfile +import fileinput +import sys +sys.path.append(os.getcwd() + '/..') + +def modify_file(): + modify_filename = '../midas/blocks.py' + copyfile(modify_filename, modify_filename+'.bak') + + with open(modify_filename, 'r') as file : + filedata = file.read() + + filedata = filedata.replace('align_corners=True', 'align_corners=False') + filedata = filedata.replace('import torch.nn as nn', 'import torch.nn as nn\nimport torchvision.models as models') + filedata = filedata.replace('torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")', 'models.resnext101_32x8d()') + + with open(modify_filename, 'w') as file: + file.write(filedata) + +def restore_file(): + modify_filename = '../midas/blocks.py' + copyfile(modify_filename+'.bak', modify_filename) + +modify_file() + +from midas.midas_net import MidasNet +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +restore_file() + + +class MidasNet_preprocessing(MidasNet): + """Network for monocular depth estimation. + """ + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + return MidasNet.forward(self, x) + + +def run(model_path): + """Run MonoDepthNN to compute depth maps. + + Args: + model_path (str): path to saved model + """ + print("initialize") + + # select device + + # load network + #model = MidasNet(model_path, non_negative=True) + model = MidasNet_preprocessing(model_path, non_negative=True) + + model.eval() + + print("start processing") + + # input + img_input = np.zeros((3, 384, 384), np.float32) + + # compute + with torch.no_grad(): + sample = torch.from_numpy(img_input).unsqueeze(0) + prediction = model.forward(sample) + prediction = ( + torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=img_input.shape[:2], + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + torch.onnx.export(model, sample, ntpath.basename(model_path).rsplit('.', 1)[0]+'.onnx', opset_version=9) + + print("finished") + + +if __name__ == "__main__": + # set paths + # MODEL_PATH = "model.pt" + MODEL_PATH = "../model-f6b98070.pt" + + # compute depth maps + run(MODEL_PATH) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..7107b99969a127f951814f743d5c562a436b2430 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py @@ -0,0 +1,119 @@ +"""Compute depth maps for images in the input folder. 
+""" +import os +import glob +import utils +import cv2 +import sys +import numpy as np +import argparse + +import onnx +import onnxruntime as rt + +from transforms import Resize, NormalizeImage, PrepareForNet + + +def run(input_path, output_path, model_path, model_type="large"): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + """ + print("initialize") + + # select device + device = "CUDA:0" + #device = "CPU" + print("device: %s" % device) + + # network resolution + if model_type == "large": + net_w, net_h = 384, 384 + elif model_type == "small": + net_w, net_h = 256, 256 + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + # load network + print("loading model...") + model = rt.InferenceSession(model_path) + input_name = model.get_inputs()[0].name + output_name = model.get_outputs()[0].name + + resize_image = Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ) + + def compose2(f1, f2): + return lambda x: f2(f1(x)) + + transform = compose2(resize_image, PrepareForNet()) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + + for ind, img_name in enumerate(img_names): + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + + # input + img = utils.read_image(img_name) + img_input = transform({"image": img})["image"] + + # compute + output = model.run([output_name], {input_name: img_input.reshape(1, 3, net_h, net_w).astype(np.float32)})[0] + prediction = np.array(output).reshape(net_h, net_w) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + utils.write_depth(filename, prediction, bits=2) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default='input', + help='folder with input images' + ) + + parser.add_argument('-o', '--output_path', + default='output', + help='folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default='model-f6b98070.onnx', + help='path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='large', + help='model type: large or small' + ) + + args = parser.parse_args() + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py new file mode 100644 index 0000000000000000000000000000000000000000..e46254f7b37f72e7d87672d70fd4b2f393ad7658 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py @@ -0,0 +1,135 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import utils +import cv2 +import argparse + +import tensorflow as tf + +from transforms import Resize, NormalizeImage, PrepareForNet + +def run(input_path, output_path, model_path, model_type="large"): + """Run MonoDepthNN to compute depth maps. 
+ + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + """ + print("initialize") + + # the runtime initialization will not allocate all memory on the device to avoid out of GPU memory + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + try: + for gpu in gpus: + #tf.config.experimental.set_memory_growth(gpu, True) + tf.config.experimental.set_virtual_device_configuration(gpu, + [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4000)]) + except RuntimeError as e: + print(e) + + # network resolution + if model_type == "large": + net_w, net_h = 384, 384 + elif model_type == "small": + net_w, net_h = 256, 256 + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + # load network + graph_def = tf.compat.v1.GraphDef() + with tf.io.gfile.GFile(model_path, 'rb') as f: + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + + model_operations = tf.compat.v1.get_default_graph().get_operations() + input_node = '0:0' + output_layer = model_operations[len(model_operations) - 1].name + ':0' + print("Last layer name: ", output_layer) + + resize_image = Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ) + + def compose2(f1, f2): + return lambda x: f2(f1(x)) + + transform = compose2(resize_image, PrepareForNet()) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + + with tf.compat.v1.Session() as sess: + try: + # load images + for ind, img_name in enumerate(img_names): + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + + # input + img = utils.read_image(img_name) + img_input = transform({"image": img})["image"] + + # compute + prob_tensor = sess.graph.get_tensor_by_name(output_layer) + prediction, = sess.run(prob_tensor, {input_node: [img_input] }) + prediction = prediction.reshape(net_h, net_w) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + utils.write_depth(filename, prediction, bits=2) + + except KeyError: + print ("Couldn't find input node: ' + input_node + ' or output layer: " + output_layer + ".") + exit(-1) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default='input', + help='folder with input images' + ) + + parser.add_argument('-o', '--output_path', + default='output', + help='folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default='model-f6b98070.pb', + help='path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='large', + help='model type: large or small' + ) + + args = parser.parse_args() + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- 
/dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. 
+ """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9a54bd55f5e31a90fad21242efbfda5a6cc1a7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py @@ -0,0 +1,82 @@ +import numpy as np +import sys +import cv2 + + +def write_pfm(path, image, scale=1): + """Write pfm file. + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + +def read_image(path): + """Read image and output RGB image (0-1). + Args: + path (str): path to file + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = 0 + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3976fd97dfe6a9dc7d4fa144be8fcb0b18b2db --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py @@ -0,0 +1,199 @@ +"""Utils for monoDepth. +""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. 
+ + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, grayscale, bits=1): + """Write depth map to png file. + + Args: + path (str): filepath without extension + depth (array): depth + grayscale (bool): use a grayscale colormap? 
+ """ + if not grayscale: + bits = 1 + + if not np.isfinite(depth).all(): + depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0) + print("WARNING: Non-finite depth values present") + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.dtype) + + if not grayscale: + out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/src/flux/annotator/zoe/zoedepth/models/builder.py b/src/flux/annotator/zoe/zoedepth/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0818311b642561712a03a66655c638ce09a04cca --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/builder.py @@ -0,0 +1,51 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module +from .depth_model import DepthModel + +def build_model(config) -> DepthModel: + """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface. + This function should be used to construct models for training and evaluation. + + Args: + config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder. + + Returns: + torch.nn.Module: Model corresponding to name and version as specified in config + """ + module_name = f"zoedepth.models.{config.model}" + try: + module = import_module(module_name) + except ModuleNotFoundError as e: + # print the original error message + print(e) + raise ValueError( + f"Model {config.model} not found. 
Refer above error for details.") from e + try: + get_version = getattr(module, "get_version") + except AttributeError as e: + raise ValueError( + f"Model {config.model} has no get_version function.") from e + return get_version(config.version_name).build_from_config(config) diff --git a/src/flux/annotator/zoe/zoedepth/models/depth_model.py b/src/flux/annotator/zoe/zoedepth/models/depth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fc421c108ea3928c9add62b4c190500d9bd4eda1 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/depth_model.py @@ -0,0 +1,152 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +import PIL.Image +from PIL import Image +from typing import Union + + +class DepthModel(nn.Module): + def __init__(self): + super().__init__() + self.device = 'cpu' + + def to(self, device) -> nn.Module: + self.device = device + return super().to(device) + + def forward(self, x, *args, **kwargs): + raise NotImplementedError + + def _infer(self, x: torch.Tensor): + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + return self(x)['metric_depth'] + + def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor: + """ + Inference interface for the model with padding augmentation + Padding augmentation fixes the boundary artifacts in the output depth map. + Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image. + This augmentation pads the input image and crops the prediction back to the original size / view. + + Note: This augmentation is not required for the models trained with 'avoid_boundary'=True. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to pad the input or not. Defaults to True. + fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3. + fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3. 
+ upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'. + padding_mode (str, optional): padding mode. Defaults to "reflect". + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # assert x is nchw and c = 3 + assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim()) + assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1]) + + if pad_input: + assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0" + pad_h = int(np.sqrt(x.shape[2]/2) * fh) + pad_w = int(np.sqrt(x.shape[3]/2) * fw) + padding = [pad_w, pad_w] + if pad_h > 0: + padding += [pad_h, pad_h] + + x = F.pad(x, padding, mode=padding_mode, **kwargs) + out = self._infer(x) + if out.shape[-2:] != x.shape[-2:]: + out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False) + if pad_input: + # crop to the original size, handling the case where pad_h and pad_w is 0 + if pad_h > 0: + out = out[:, :, pad_h:-pad_h,:] + if pad_w > 0: + out = out[:, :, :, pad_w:-pad_w] + return out + + def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model with horizontal flip augmentation + Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # infer with horizontal flip and average + out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs) + out = (out + torch.flip(out_flip, dims=[3])) / 2 + return out + + def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + if with_flip_aug: + return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs) + else: + return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + + @torch.no_grad() + def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]: + """ + Inference interface for the model for PIL image + Args: + pil_img (PIL.Image.Image): input PIL image + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy". 
+ """ + x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device) + out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs) + if output_type == "numpy": + return out_tensor.squeeze().cpu().numpy() + elif output_type == "pil": + # uint16 is required for depth pil image + out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16) + return Image.fromarray(out_16bit_numpy) + elif output_type == "tensor": + return out_tensor.squeeze().cpu() + else: + raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'") + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py b/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8efe645adea1d88a12e2ac5cc6bb2a251eef9d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py @@ -0,0 +1,208 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +@torch.jit.script +def exp_attractor(dx, alpha: float = 300, gamma: int = 2): + """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx) + + +@torch.jit.script +def inv_attractor(dx, alpha: float = 300, gamma: int = 2): + """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center + This is the default one according to the accompanying paper. + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. 
Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return dx.div(1+alpha*dx.pow(gamma)) + + +class AttractorLayer(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth) + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm + nn.ReLU(inplace=True) + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + eps = 1e-3 + A = A + eps + n, c, h, w = A.shape + A = A.view(n, self.n_attractors, 2, h, w) + A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w + A_normed = A[:, :, 0, ...] # n, na, h, w + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func(dist(A_normed.unsqueeze( + 2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + # .shape N, nbins, h, w + delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers) + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = (self.max_depth - self.min_depth) * \ + b_new_centers + self.min_depth + B_centers, _ = torch.sort(B_centers, dim=1) + B_centers = torch.clip(B_centers, self.min_depth, self.max_depth) + return b_new_centers, B_centers + + +class AttractorLayerUnnormed(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. 
Bin centers are unbounded + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + n, c, h, w = A.shape + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func( + dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + delta_c += dist(A[:, i, ...].unsqueeze(1) - + b_centers) # .shape N, nbins, h, w + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = b_new_centers + + return b_new_centers, B_centers diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py b/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3208405dfb78fdfc28d5765e5a6d5dbe31967a23 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py @@ -0,0 +1,121 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
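
For reference, the following is a minimal, self-contained sketch of the inverse-attractor update used by the attractor layers above. The toy bin centers and attractor points are invented for illustration, only the `'sum'` aggregation path is shown, and the broadcasting mirrors `AttractorLayer.forward` (attractors unsqueezed on dim 2, bin centers on dim 1, then summed over the attractor dimension).

```python
# Illustrative sketch only (not part of the patch): how inv_attractor from
# zoedepth/models/layers/attractor.py pulls normalized bin centers toward
# attractor points. All numeric values below are made up for demonstration.
import torch


def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
    # dc = dx / (1 + alpha * dx**gamma), as stated in the docstring above
    return dx.div(1 + alpha * dx.pow(gamma))


# toy example: 4 bin centers and 2 attractor points for a single pixel (n=1, h=w=1)
b_centers = torch.tensor([0.10, 0.30, 0.60, 0.90]).view(1, 4, 1, 1)   # n, nbins, h, w
attractors = torch.tensor([0.25, 0.75]).view(1, 2, 1, 1)              # n, na, h, w

# pairwise differences a_i - c_j, aggregated over attractors ('sum' kind)
delta_c = inv_attractor(attractors.unsqueeze(2) - b_centers.unsqueeze(1)).sum(dim=1)
b_new_centers = b_centers + delta_c
print(b_new_centers.flatten())  # centers are shifted toward 0.25 and 0.75
```
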
+ +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +def log_binom(n, k, eps=1e-7): + """ log(nCk) using stirling approximation """ + n = n + eps + k = k + eps + return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps) + + +class LogBinomial(nn.Module): + def __init__(self, n_classes=256, act=torch.softmax): + """Compute log binomial distribution for n_classes + + Args: + n_classes (int, optional): number of output classes. Defaults to 256. + """ + super().__init__() + self.K = n_classes + self.act = act + self.register_buffer('k_idx', torch.arange( + 0, n_classes).view(1, -1, 1, 1)) + self.register_buffer('K_minus_1', torch.Tensor( + [self.K-1]).view(1, -1, 1, 1)) + + def forward(self, x, t=1., eps=1e-4): + """Compute log binomial distribution for x + + Args: + x (torch.Tensor - NCHW): probabilities + t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.. + eps (float, optional): Small number for numerical stability. Defaults to 1e-4. + + Returns: + torch.Tensor -NCHW: log binomial distribution logbinomial(p;t) + """ + if x.ndim == 3: + x = x.unsqueeze(1) # make it nchw + + one_minus_x = torch.clamp(1 - x, eps, 1) + x = torch.clamp(x, eps, 1) + y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \ + torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x) + return self.act(y/t, dim=1) + + +class ConditionalLogBinomial(nn.Module): + def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax): + """Conditional Log Binomial distribution + + Args: + in_features (int): number of input channels in main feature + condition_dim (int): number of input channels in condition feature + n_classes (int, optional): Number of classes. Defaults to 256. + bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2. + p_eps (float, optional): small eps value. Defaults to 1e-4. + max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50. + min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7. + """ + super().__init__() + self.p_eps = p_eps + self.max_temp = max_temp + self.min_temp = min_temp + self.log_binomial_transform = LogBinomial(n_classes, act=act) + bottleneck = (in_features + condition_dim) // bottleneck_factor + self.mlp = nn.Sequential( + nn.Conv2d(in_features + condition_dim, bottleneck, + kernel_size=1, stride=1, padding=0), + nn.GELU(), + # 2 for p linear norm, 2 for t linear norm + nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0), + nn.Softplus() + ) + + def forward(self, x, cond): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Main feature + cond (torch.Tensor - NCHW): condition feature + + Returns: + torch.Tensor: Output log binomial distribution + """ + pt = self.mlp(torch.concat((x, cond), dim=1)) + p, t = pt[:, :2, ...], pt[:, 2:, ...] + + p = p + self.p_eps + p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...]) + + t = t + self.p_eps + t = t[:, 0, ...] / (t[:, 0, ...] 
+ t[:, 1, ...]) + t = t.unsqueeze(1) + t = (self.max_temp - self.min_temp) * t + self.min_temp + + return self.log_binomial_transform(p, t) diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py b/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f94481605c3e6958ce50e73b2eb31d9f0c07dc67 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py @@ -0,0 +1,169 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class SeedBinRegressor(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval. + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Min depth value. Defaults to 1e-3. + max_depth (float, optional): Max depth value. Defaults to 10. + """ + super().__init__() + self.version = "1_1" + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B = self._net(x) + eps = 1e-3 + B = B + eps + B_widths_normed = B / B.sum(dim=1, keepdim=True) + B_widths = (self.max_depth - self.min_depth) * \ + B_widths_normed # .shape NCHW + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad( + B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...]) + return B_widths_normed, B_centers + + +class SeedBinRegressorUnnormed(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are unbounded + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Not used. 
(for compatibility with SeedBinRegressor) + max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + """ + super().__init__() + self.version = "1_1" + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B_centers = self._net(x) + return B_centers, B_centers + + +class Projector(nn.Module): + def __init__(self, in_features, out_features, mlp_dim=128): + """Projector MLP + + Args: + in_features (int): input channels + out_features (int): output channels + mlp_dim (int, optional): hidden dimension. Defaults to 128. + """ + super().__init__() + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, out_features, 1, 1, 0), + ) + + def forward(self, x): + return self._net(x) + + + +class LinearSplitter(nn.Module): + def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10): + super().__init__() + + self.prev_nbins = prev_nbins + self.split_factor = split_factor + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.GELU(), + nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0), + nn.ReLU() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + x : feature block; shape - n, c, h, w + b_prev : previous bin widths normed; shape - n, prev_nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + S = self._net(x) + eps = 1e-3 + S = S + eps + n, c, h, w = S.shape + S = S.view(n, self.prev_nbins, self.split_factor, h, w) + S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits + + b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True) + + + b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees + # print(b_prev.shape, S_normed.shape) + # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat? + b = b_prev.unsqueeze(2) * S_normed + b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w + + # calculate bin centers for loss calculation + B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] 
+ B_edges[:,1:,...]) + return b, B_centers \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py b/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..99d9e51a06b981bae45ce7dd64eaef19a4121991 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False): + """ViT-like transformer block + + Args: + in_channels (int): Input channels + patch_size (int, optional): patch size. Defaults to 10. + embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128. + num_heads (int, optional): number of attention heads. Defaults to 4. + use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False. 
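A minimal standalone sketch of the width-to-center computation that SeedBinRegressor and LinearSplitter above share (toy shapes, with softmax standing in for the regressor's own normalization; none of this is part of the patch): normalized widths are scaled to the metric range, a leading min_depth slot is padded in, a cumulative sum gives bin edges, and centers are midpoints of consecutive edges.

import torch
import torch.nn as nn

# Toy illustration of the widths -> edges -> centers logic (shapes are arbitrary).
min_depth, max_depth = 1e-3, 10.0
B_widths_normed = torch.softmax(torch.randn(2, 16, 8, 8), dim=1)              # per-pixel widths, sum to 1 over bins
B_widths = (max_depth - min_depth) * B_widths_normed                          # widths span the metric depth range
B_widths = nn.functional.pad(B_widths, (0, 0, 0, 0, 1, 0), value=min_depth)   # prepend a min_depth slot on the bin axis
B_edges = torch.cumsum(B_widths, dim=1)                                       # bin edges; the first edge equals min_depth
B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])                # midpoints, shape (2, 16, 8, 8)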
+ """ + super(PatchTransformerEncoder, self).__init__() + self.use_class_token = use_class_token + encoder_layers = nn.TransformerEncoderLayer( + embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): + """Generate positional encodings + + Args: + sequence_length (int): Sequence length + embedding_dim (int): Embedding dimension + + Returns: + torch.Tensor SBE: Positional encodings + """ + position = torch.arange( + 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) + index = torch.arange( + 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) + div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) + pos_encoding = position * div_term + pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) + pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) + return pos_encoding + + + def forward(self, x): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Input feature tensor + + Returns: + torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim + """ + embeddings = self.embedding_convPxP(x).flatten( + 2) # .shape = n,c,s = n, embedding_dim, s + if self.use_class_token: + # extra special token at start ? + embeddings = nn.functional.pad(embeddings, (1, 0)) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + S, N, E = embeddings.shape + embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x diff --git a/src/flux/annotator/zoe/zoedepth/models/model_io.py b/src/flux/annotator/zoe/zoedepth/models/model_io.py new file mode 100644 index 0000000000000000000000000000000000000000..78b6579631dd847ac76651238cb5a948b5a66286 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/model_io.py @@ -0,0 +1,92 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
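A rough smoke test for the PatchTransformerEncoder defined above, under assumed toy shapes (the import path follows this patch's package layout; nothing here is part of the patch itself):

import torch
# Assumed import path per this patch's layout:
# from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder

encoder = PatchTransformerEncoder(in_channels=256, patch_size=10, embedding_dim=128, num_heads=4)
feats = torch.randn(2, 256, 40, 50)   # N, C, H, W feature map
tokens = encoder(feats)               # S, N, E with S = (40 // 10) * (50 // 10) = 20 patches
print(tokens.shape)                   # torch.Size([20, 2, 128])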
+ +# File author: Shariq Farooq Bhat + +import torch + +def load_state_dict(model, state_dict): + """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict. + + DataParallel prefixes state_dict keys with 'module.' when saving. + If the model is not a DataParallel model but the state_dict is, then prefixes are removed. + If the model is a DataParallel model but the state_dict is not, then prefixes are added. + """ + state_dict = state_dict.get('model', state_dict) + # if model is a DataParallel model, then state_dict keys are prefixed with 'module.' + + do_prefix = isinstance( + model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) + state = {} + for k, v in state_dict.items(): + if k.startswith('module.') and not do_prefix: + k = k[7:] + + if not k.startswith('module.') and do_prefix: + k = 'module.' + k + + state[k] = v + + model.load_state_dict(state) + print("Loaded successfully") + return model + + +def load_wts(model, checkpoint_path): + ckpt = torch.load(checkpoint_path, map_location='cpu') + return load_state_dict(model, ckpt) + + +def load_state_dict_from_url(model, url, **kwargs): + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs) + return load_state_dict(model, state_dict) + + +def load_state_from_resource(model, resource: str): + """Loads weights to the model from a given resource. A resource can be of following types: + 1. URL. Prefixed with "url::" + e.g. url::http(s)://url.resource.com/ckpt.pt + + 2. Local path. Prefixed with "local::" + e.g. local::/path/to/ckpt.pt + + + Args: + model (torch.nn.Module): Model + resource (str): resource string + + Returns: + torch.nn.Module: Model with loaded weights + """ + print(f"Using pretrained resource {resource}") + + if resource.startswith('url::'): + url = resource.split('url::')[1] + return load_state_dict_from_url(model, url, progress=True) + + elif resource.startswith('local::'): + path = resource.split('local::')[1] + return load_wts(model, path) + + else: + raise ValueError("Invalid resource type, only url:: and local:: are supported") + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_v1 import ZoeDepth + +all_versions = { + "v1": ZoeDepth, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json new file mode 100644 index 0000000000000000000000000000000000000000..3112ed78c89f00e1d13f5d6e5be87cd3216b6dc7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json @@ -0,0 +1,58 @@ +{ + "model": { + "name": "ZoeDepth", + "version_name": "v1", + "n_bins": 64, + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "midas_model_type" : "DPT_BEiT_L_384", + "min_temp": 0.0212, + "max_temp": 50.0, + "output_distribution": "logbinomial", + "memory_efficient": true, + "inverse_midas": false, + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.000161, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 0.2, + "w_reg": 0, + "w_grad": 0, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "midas_lr_factor": 1, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10, + "freeze_midas_bn": true + + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null, + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null + } +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json new file mode 100644 index 0000000000000000000000000000000000000000..b51802aa44b91c39e15aacaac4b5ab6bec884414 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json @@ -0,0 +1,22 @@ +{ + "model": { + "bin_centers_type": "normed", + "img_size": [384, 768] + }, + + "train": { + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt", + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" + } +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..bc931b059d6165c84e8ff4f09d5c62d19930cee9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py @@ -0,0 +1,250 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person 
obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..model_io import load_state_from_resource + + +class ZoeDepth(DepthModel): + def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, + midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepth model. This is the version of ZoeDepth that has a single metric head + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + n_bins (int, optional): Number of bin centers. Defaults to 64. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. + max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. 
+ train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + """ + super().__init__() + + self.core = core + self.max_depth = max_depth + self.min_depth = min_depth + self.min_temp = min_temp + self.bin_centers_type = bin_centers_type + + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.train_midas = train_midas + self.inverse_midas = inverse_midas + + if self.encoder_lr_factor <= 0: + self.core.freeze_encoder( + freeze_rel_pos=self.pos_enc_lr_factor <= 0) + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + + self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, + kernel_size=1, stride=1, padding=0) # btlnck conv + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + + self.seed_bin_regressor = SeedBinRegressorLayer( + btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) + self.seed_projector = Projector(btlnck_features, bin_embedding_dim) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim) + for num_out in num_out_features + ]) + self.attractors = nn.ModuleList([ + Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, + alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) + for i in range(len(num_out_features)) + ]) + + last_in = N_MIDAS_OUT + 1 # +1 for relative depth + + # use log binomial instead of softmax + self.conditional_log_binomial = ConditionalLogBinomial( + last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W) + return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. + return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. + + Returns: + dict: Dictionary containing the following keys: + - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) + - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) + - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). 
Present only if return_final_centers is True + - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True + + """ + b, c, h, w = x.shape + # print("input shape ", x.shape) + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + # print("output shapes", rel_depth.shape, out.shape) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + _, seed_b_centers = self.seed_bin_regressor(x) + + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - self.min_depth) / \ + (self.max_depth - self.min_depth) + else: + b_prev = seed_b_centers + + prev_b_embedding = self.seed_projector(x) + + # unroll this loop for better performance + for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b.clone() + prev_b_embedding = b_embedding.clone() + + last = outconv_activation + + if self.inverse_midas: + # invert depth followed by normalization + rel_depth = 1.0 / (rel_depth + 1e-6) + rel_depth = (rel_depth - rel_depth.min()) / \ + (rel_depth.max() - rel_depth.min()) + # concat rel depth with last. First interpolate rel depth to last size + rel_cond = rel_depth.unsqueeze(1) + rel_cond = nn.functional.interpolate( + rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) + last = torch.cat([last, rel_cond], dim=1) + + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + x = self.conditional_log_binomial(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + b_centers = nn.functional.interpolate( + b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + # Structure output dict + output = dict(metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. 
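For clarity, the final step of ZoeDepth.forward above is just a per-pixel expectation of bin centers under the predicted distribution; a toy version of that single line (shapes and values are illustrative only):

import torch

# probs plays the role of the conditional log-binomial output; the centers here are made up.
probs = torch.softmax(torch.randn(1, 64, 24, 32), dim=1)                       # (B, n_bins, H, W)
b_centers = torch.linspace(1e-3, 10.0, 64).view(1, 64, 1, 1).expand(1, 64, 24, 32)
metric_depth = torch.sum(probs * b_centers, dim=1, keepdim=True)               # (B, 1, H, W)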
+ """ + param_conf = [] + if self.train_midas: + if self.encoder_lr_factor > 0: + param_conf.append({'params': self.core.get_enc_params_except_rel_pos( + ), 'lr': lr / self.encoder_lr_factor}) + + if self.pos_enc_lr_factor > 0: + param_conf.append( + {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) + + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor + param_conf.append( + {'params': midas_params, 'lr': lr / midas_lr_factor}) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + + param_conf.append({'params': remaining_params, 'lr': lr}) + + return param_conf + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepth(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepth.build(**config) diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..513a278b939c10c010e3c0250ec73544d5663886 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
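Putting the pieces together, a hedged usage sketch for the single-head model: ZoeDepth.build constructs the MiDaS core (which is expected to fetch the backbone via torch.hub when use_pretrained_midas is true), and load_state_from_resource from model_io.py resolves the "url::" / "local::" prefixes seen in the json configs. The checkpoint path below is a placeholder and the top-level import is an assumption based on this patch's layout.

import torch
# Assumed import path per this patch's layout:
# from zoedepth.models.zoedepth import ZoeDepth

model = ZoeDepth.build(
    midas_model_type="DPT_BEiT_L_384",
    use_pretrained_midas=True,
    train_midas=False,
    n_bins=64,
    bin_centers_type="softplus",
    pretrained_resource="local::/path/to/ckpt.pt",   # or "url::https://..." as in the json configs
)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 384, 512))          # unnormalized noise, shape smoke test only
    depth = out["metric_depth"]                      # (1, 1, H', W')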
+ +# File author: Shariq Farooq Bhat + +from .zoedepth_nk_v1 import ZoeDepthNK + +all_versions = { + "v1": ZoeDepthNK, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json new file mode 100644 index 0000000000000000000000000000000000000000..42bab2a3ad159a09599a5aba270c491021a3cf1a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json @@ -0,0 +1,67 @@ +{ + "model": { + "name": "ZoeDepthNK", + "version_name": "v1", + "bin_conf" : [ + { + "name": "nyu", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 10.0 + }, + { + "name": "kitti", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 80.0 + } + ], + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "min_temp": 0.0212, + "max_temp": 50.0, + "memory_efficient": true, + "midas_model_type" : "DPT_BEiT_L_384", + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth_nk", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.0002512, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 100, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "w_grad": 0, + "w_reg": 0, + "midas_lr_factor": 10, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10 + }, + + "infer": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false, + "force_keep_ar": true + }, + + "eval": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false + } +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..7368ae8031188a9f946d9d3f29633c96e791e68e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py @@ -0,0 +1,333 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn + +from zoedepth.models.depth_model import DepthModel +from zoedepth.models.base_models.midas import MidasCore +from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed +from zoedepth.models.layers.dist_layers import ConditionalLogBinomial +from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder +from zoedepth.models.model_io import load_state_from_resource + + +class ZoeDepthNK(DepthModel): + def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', + min_temp=5, max_temp=50, + memory_efficient=False, train_midas=True, + is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts. + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + + bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys: + "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float) + + The length of this list determines the number of metric heads. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "normed". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + + memory_efficient (bool, optional): Whether to use memory efficient version of attractor layers. Memory efficient version is slower but is recommended incase of multiple metric heads in order save GPU memory. Defaults to False. + + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True. 
+ midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + + """ + + super().__init__() + + self.core = core + self.bin_conf = bin_conf + self.min_temp = min_temp + self.max_temp = max_temp + self.memory_efficient = memory_efficient + self.train_midas = train_midas + self.is_midas_pretrained = is_midas_pretrained + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.inverse_midas = inverse_midas + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + # self.scales = [16, 8, 4, 2] # spatial scale factors + + self.conv2 = nn.Conv2d( + btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0) + + # Transformer classifier on the bottleneck + self.patch_transformer = PatchTransformerEncoder( + btlnck_features, 1, 128, use_class_token=True) + self.mlp_classifier = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 2) + ) + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + self.bin_centers_type = bin_centers_type + # We have bins for each bin conf. + # Create a map (ModuleDict) of 'name' -> seed_bin_regressor + self.seed_bin_regressors = nn.ModuleDict( + {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for conf in bin_conf} + ) + + self.seed_projector = Projector( + btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + for num_out in num_out_features + ]) + + # Create a map (ModuleDict) of 'name' -> attractors (ModuleList) + self.attractors = nn.ModuleDict( + {conf['name']: nn.ModuleList([ + Attractor(bin_embedding_dim, n_attractors[i], + mlp_dim=bin_embedding_dim, alpha=attractor_alpha, + gamma=attractor_gamma, kind=attractor_kind, + attractor_type=attractor_type, memory_efficient=memory_efficient, + min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for i in range(len(n_attractors)) + ]) + for conf in bin_conf} + ) + + last_in = N_MIDAS_OUT + # conditional log binomial for each bin conf + self.conditional_log_binomial = nn.ModuleDict( + {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp) + for conf in bin_conf} + ) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain. 
+ return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. Defaults to False. + return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False. + + Returns: + dict: Dictionary of outputs with keys: + - "rel_depth": Relative depth map of shape (B, 1, H, W) + - "metric_depth": Metric depth map of shape (B, 1, H, W) + - "domain_logits": Domain logits of shape (B, 2) + - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True + - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True + """ + b, c, h, w = x.shape + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + + # Predict which path to take + embedding = self.patch_transformer(x)[0] # N, E + domain_logits = self.mlp_classifier(embedding) # N, 2 + domain_vote = torch.softmax(domain_logits.sum( + dim=0, keepdim=True), dim=-1) # 1, 2 + + # Get the path + bin_conf_name = ["nyu", "kitti"][torch.argmax( + domain_vote, dim=-1).squeeze().item()] + + try: + conf = [c for c in self.bin_conf if c.name == bin_conf_name][0] + except IndexError: + raise ValueError( + f"bin_conf_name {bin_conf_name} not found in bin_confs") + + min_depth = conf['min_depth'] + max_depth = conf['max_depth'] + + seed_bin_regressor = self.seed_bin_regressors[bin_conf_name] + _, seed_b_centers = seed_bin_regressor(x) + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth) + else: + b_prev = seed_b_centers + prev_b_embedding = self.seed_projector(x) + + attractors = self.attractors[bin_conf_name] + for projector, attractor, x in zip(self.projectors, attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b + prev_b_embedding = b_embedding + + last = outconv_activation + + b_centers = nn.functional.interpolate( + b_centers, last.shape[-2:], mode='bilinear', align_corners=True) + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + + clb = self.conditional_log_binomial[bin_conf_name] + x = clb(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + output = dict(domain_logits=domain_logits, metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. 
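The routing step near the top of ZoeDepthNK.forward above collapses per-image domain logits into one dataset choice for the whole batch; a standalone sketch of just that vote (toy logits; the head names follow the bin_conf in config_zoedepth_nk.json):

import torch

# Three images' logits over the ("nyu", "kitti") heads; the batch votes as a whole.
domain_logits = torch.tensor([[ 2.0, -1.0],
                              [ 1.5, -0.5],
                              [-0.2,  0.1]])                                    # (N, 2)
domain_vote = torch.softmax(domain_logits.sum(dim=0, keepdim=True), dim=-1)     # (1, 2)
bin_conf_name = ["nyu", "kitti"][torch.argmax(domain_vote, dim=-1).squeeze().item()]
print(bin_conf_name)   # "nyu" here: the summed logits favour the first head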
+ """ + param_conf = [] + if self.train_midas: + def get_rel_pos_params(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + encoder_params = get_enc_params_except_rel_pos() + rel_pos_params = get_rel_pos_params() + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0 + param_conf.extend([ + {'params': encoder_params, 'lr': lr / self.encoder_lr_factor}, + {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor}, + {'params': midas_params, 'lr': lr / midas_lr_factor} + ]) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + param_conf.append({'params': remaining_params, 'lr': lr}) + return param_conf + + def get_conf_parameters(self, conf_name): + """ + Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + params = [] + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + if bin_conf_name == conf_name: + params += list(module.parameters()) + return params + + def freeze_conf(self, conf_name): + """ + Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = False + + def unfreeze_conf(self, conf_name): + """ + Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = True + + def freeze_all_confs(self): + """ + Freezes all the parameters of all the ModuleDicts children + """ + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + for p in module.parameters(): + p.requires_grad = False + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepthNK(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepthNK.build(**config) diff --git a/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..33fbbea3a7d49efe11b005adb5127f441eabfaf6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py @@ -0,0 +1,326 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without 
limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os +import uuid +import warnings +from datetime import datetime as dt +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +import wandb +from tqdm import tqdm + +from zoedepth.utils.config import flatten +from zoedepth.utils.misc import RunningAverageDict, colorize, colors + + +def is_rank_zero(args): + return args.rank == 0 + + +class BaseTrainer: + def __init__(self, config, model, train_loader, test_loader=None, device=None): + """ Base Trainer class for training a model.""" + + self.config = config + self.metric_criterion = "abs_rel" + if device is None: + device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') + self.device = device + self.model = model + self.train_loader = train_loader + self.test_loader = test_loader + self.optimizer = self.init_optimizer() + self.scheduler = self.init_scheduler() + + def resize_to_target(self, prediction, target): + if prediction.shape[2:] != target.shape[-2:]: + prediction = nn.functional.interpolate( + prediction, size=target.shape[-2:], mode="bilinear", align_corners=True + ) + return prediction + + def load_ckpt(self, checkpoint_dir="./checkpoints", ckpt_type="best"): + import glob + import os + + from zoedepth.models.model_io import load_wts + + if hasattr(self.config, "checkpoint"): + checkpoint = self.config.checkpoint + elif hasattr(self.config, "ckpt_pattern"): + pattern = self.config.ckpt_pattern + matches = glob.glob(os.path.join( + checkpoint_dir, f"*{pattern}*{ckpt_type}*")) + if not (len(matches) > 0): + raise ValueError(f"No matches found for the pattern {pattern}") + checkpoint = matches[0] + else: + return + model = load_wts(self.model, checkpoint) + # TODO : Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it. + print("Loaded weights from {0}".format(checkpoint)) + warnings.warn( + "Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.") + self.model = model + + def init_optimizer(self): + m = self.model.module if self.config.multigpu else self.model + + if self.config.same_lr: + print("Using same LR") + if hasattr(m, 'core'): + m.core.unfreeze() + params = self.model.parameters() + else: + print("Using diff LR") + if not hasattr(m, 'get_lr_params'): + raise NotImplementedError( + f"Model {m.__class__.__name__} does not implement get_lr_params. 
Please implement it or use the same LR for all parameters.") + + params = m.get_lr_params(self.config.lr) + + return optim.AdamW(params, lr=self.config.lr, weight_decay=self.config.wd) + + def init_scheduler(self): + lrs = [l['lr'] for l in self.optimizer.param_groups] + return optim.lr_scheduler.OneCycleLR(self.optimizer, lrs, epochs=self.config.epochs, steps_per_epoch=len(self.train_loader), + cycle_momentum=self.config.cycle_momentum, + base_momentum=0.85, max_momentum=0.95, div_factor=self.config.div_factor, final_div_factor=self.config.final_div_factor, pct_start=self.config.pct_start, three_phase=self.config.three_phase) + + def train_on_batch(self, batch, train_step): + raise NotImplementedError + + def validate_on_batch(self, batch, val_step): + raise NotImplementedError + + def raise_if_nan(self, losses): + for key, value in losses.items(): + if torch.isnan(value): + raise ValueError(f"{key} is NaN, Stopping training") + + @property + def iters_per_epoch(self): + return len(self.train_loader) + + @property + def total_iters(self): + return self.config.epochs * self.iters_per_epoch + + def should_early_stop(self): + if self.config.get('early_stop', False) and self.step > self.config.early_stop: + return True + + def train(self): + print(f"Training {self.config.name}") + if self.config.uid is None: + self.config.uid = str(uuid.uuid4()).split('-')[-1] + run_id = f"{dt.now().strftime('%d-%h_%H-%M')}-{self.config.uid}" + self.config.run_id = run_id + self.config.experiment_id = f"{self.config.name}{self.config.version_name}_{run_id}" + self.should_write = ((not self.config.distributed) + or self.config.rank == 0) + self.should_log = self.should_write # and logging + if self.should_log: + tags = self.config.tags.split( + ',') if self.config.tags != '' else None + wandb.init(project=self.config.project, name=self.config.experiment_id, config=flatten(self.config), dir=self.config.root, + tags=tags, notes=self.config.notes, settings=wandb.Settings(start_method="fork")) + + self.model.train() + self.step = 0 + best_loss = np.inf + validate_every = int(self.config.validate_every * self.iters_per_epoch) + + + if self.config.prefetch: + + for i, batch in tqdm(enumerate(self.train_loader), desc=f"Prefetching...", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader): + pass + + losses = {} + def stringify_losses(L): return "; ".join(map( + lambda kv: f"{colors.fg.purple}{kv[0]}{colors.reset}: {round(kv[1].item(),3):.4e}", L.items())) + for epoch in range(self.config.epochs): + if self.should_early_stop(): + break + + self.epoch = epoch + ################################# Train loop ########################################################## + if self.should_log: + wandb.log({"Epoch": epoch}, step=self.step) + pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader) + for i, batch in pbar: + if self.should_early_stop(): + print("Early stopping") + break + # print(f"Batch {self.step+1} on rank {self.config.rank}") + losses = self.train_on_batch(batch, i) + # print(f"trained batch {self.step+1} on rank {self.config.rank}") + + self.raise_if_nan(losses) + if is_rank_zero(self.config) and self.config.print_losses: + pbar.set_description( + f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train. 
Losses: {stringify_losses(losses)}") + self.scheduler.step() + + if self.should_log and self.step % 50 == 0: + wandb.log({f"Train/{name}": loss.item() + for name, loss in losses.items()}, step=self.step) + + self.step += 1 + + ######################################################################################################## + + if self.test_loader: + if (self.step % validate_every) == 0: + self.model.eval() + if self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_latest.pt") + + ################################# Validation loop ################################################## + # validate on the entire validation set in every process but save only from rank 0, I know, inefficient, but avoids divergence of processes + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log( + {f"Test/{name}": tloss for name, tloss in test_losses.items()}, step=self.step) + + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + if self.config.distributed: + dist.barrier() + # print(f"Validated: {metrics} on device {self.config.rank}") + + # print(f"Finished step {self.step} on device {self.config.rank}") + ################################################################################################# + + # Save / validate at the end + self.step += 1 # log as final point + self.model.eval() + self.save_checkpoint(f"{self.config.experiment_id}_latest.pt") + if self.test_loader: + + ################################# Validation loop ################################################## + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log({f"Test/{name}": tloss for name, + tloss in test_losses.items()}, step=self.step) + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + def validate(self): + with torch.no_grad(): + losses_avg = RunningAverageDict() + metrics_avg = RunningAverageDict() + for i, batch in tqdm(enumerate(self.test_loader), desc=f"Epoch: {self.epoch + 1}/{self.config.epochs}. 
Loop: Validation", total=len(self.test_loader), disable=not is_rank_zero(self.config)): + metrics, losses = self.validate_on_batch(batch, val_step=i) + + if losses: + losses_avg.update(losses) + if metrics: + metrics_avg.update(metrics) + + return metrics_avg.get_value(), losses_avg.get_value() + + def save_checkpoint(self, filename): + if not self.should_write: + return + root = self.config.save_dir + if not os.path.isdir(root): + os.makedirs(root) + + fpath = os.path.join(root, filename) + m = self.model.module if self.config.multigpu else self.model + torch.save( + { + "model": m.state_dict(), + "optimizer": None, # TODO : Change to self.optimizer.state_dict() if resume support is needed, currently None to reduce file size + "epoch": self.epoch + }, fpath) + + def log_images(self, rgb: Dict[str, list] = {}, depth: Dict[str, list] = {}, scalar_field: Dict[str, list] = {}, prefix="", scalar_cmap="jet", min_depth=None, max_depth=None): + if not self.should_log: + return + + if min_depth is None: + try: + min_depth = self.config.min_depth + max_depth = self.config.max_depth + except AttributeError: + min_depth = None + max_depth = None + + depth = {k: colorize(v, vmin=min_depth, vmax=max_depth) + for k, v in depth.items()} + scalar_field = {k: colorize( + v, vmin=None, vmax=None, cmap=scalar_cmap) for k, v in scalar_field.items()} + images = {**rgb, **depth, **scalar_field} + wimages = { + prefix+"Predictions": [wandb.Image(v, caption=k) for k, v in images.items()]} + wandb.log(wimages, step=self.step) + + def log_line_plot(self, data): + if not self.should_log: + return + + plt.plot(data) + plt.ylabel("Scale factors") + wandb.log({"Scale factors": wandb.Image(plt)}, step=self.step) + plt.close() + + def log_bar_plot(self, title, labels, values): + if not self.should_log: + return + + data = [[label, val] for (label, val) in zip(labels, values)] + table = wandb.Table(data=data, columns=["label", "value"]) + wandb.log({title: wandb.plot.bar(table, "label", + "value", title=title)}, step=self.step) diff --git a/src/flux/annotator/zoe/zoedepth/trainers/builder.py b/src/flux/annotator/zoe/zoedepth/trainers/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a663541b08912ebedce21a68c7599ce4c06e85d0 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/builder.py @@ -0,0 +1,48 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
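BaseTrainer above deliberately leaves train_on_batch and validate_on_batch unimplemented; the concrete trainer selected by get_trainer in builder.py below supplies them. A minimal, hypothetical subclass is sketched here only to show that contract (the batch keys, the L1 loss, and the class name are illustrative assumptions, not the repo's actual zoedepth trainer):

import torch
import torch.nn as nn

class ToyTrainer(BaseTrainer):  # assumes BaseTrainer from the file above is importable
    def train_on_batch(self, batch, train_step):
        # Hypothetical batch keys; depths assumed (N, 1, H, W). Return a dict of named loss tensors.
        images = batch["image"].to(self.device)
        depths = batch["depth"].to(self.device)
        pred = self.model(images)["metric_depth"]
        loss = nn.functional.l1_loss(self.resize_to_target(pred, depths), depths)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"l1": loss}

    def validate_on_batch(self, batch, val_step):
        # Return (metrics, losses); the "abs_rel" key matches self.metric_criterion used for checkpointing.
        with torch.no_grad():
            images = batch["image"].to(self.device)
            depths = batch["depth"].to(self.device)
            pred = self.model(images)["metric_depth"]
            loss = nn.functional.l1_loss(self.resize_to_target(pred, depths), depths)
        return {"abs_rel": loss.item()}, {"l1": loss}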
+ +# File author: Shariq Farooq Bhat + +from importlib import import_module + + +def get_trainer(config): + """Builds and returns a trainer based on the config. + + Args: + config (dict): the config dict (typically constructed using utils.config.get_config) + config.trainer (str): the name of the trainer to use. The module named "{config.trainer}_trainer" must exist in trainers root module + + Raises: + ValueError: If the specified trainer does not exist under trainers/ folder + + Returns: + Trainer (inherited from zoedepth.trainers.BaseTrainer): The Trainer object + """ + assert "trainer" in config and config.trainer is not None and config.trainer != '', "Trainer not specified. Config: {0}".format( + config) + try: + Trainer = getattr(import_module( + f"zoedepth.trainers.{config.trainer}_trainer"), 'Trainer') + except ModuleNotFoundError as e: + raise ValueError(f"Trainer {config.trainer}_trainer not found.") from e + return Trainer diff --git a/src/flux/annotator/zoe/zoedepth/trainers/loss.py b/src/flux/annotator/zoe/zoedepth/trainers/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5a1c15cdf5628c1474c566fdc6e58159d7f5ab --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/loss.py @@ -0,0 +1,316 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.cuda.amp as amp +import numpy as np + + +KEY_OUTPUT = 'metric_depth' + + +def extract_key(prediction, key): + if isinstance(prediction, dict): + return prediction[key] + return prediction + + +# Main loss function used for ZoeDepth. 
Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7) +class SILogLoss(nn.Module): + """SILog loss (pixel-wise)""" + def __init__(self, beta=0.15): + super(SILogLoss, self).__init__() + self.name = 'SILog' + self.beta = beta + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + if target.ndim == 3: + target = target.unsqueeze(1) + + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + input = input[mask] + target = target[mask] + + with amp.autocast(enabled=False): # amp causes NaNs in this loss function + alpha = 1e-7 + g = torch.log(input + alpha) - torch.log(target + alpha) + + # n, c, h, w = g.shape + # norm = 1/(h*w) + # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2 + + Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2) + + loss = 10 * torch.sqrt(Dg) + + if torch.isnan(loss): + print("Nan SILog loss") + print("input:", input.shape) + print("target:", target.shape) + print("G", torch.sum(torch.isnan(g))) + print("Input min max", torch.min(input), torch.max(input)) + print("Target min max", torch.min(target), torch.max(target)) + print("Dg", torch.isnan(Dg)) + print("loss", torch.isnan(loss)) + + if not return_interpolated: + return loss + + return loss, intr_input + + +def grad(x): + # x.shape : n, c, h, w + diff_x = x[..., 1:, 1:] - x[..., 1:, :-1] + diff_y = x[..., 1:, 1:] - x[..., :-1, 1:] + mag = diff_x**2 + diff_y**2 + # angle_ratio + angle = torch.atan(diff_y / (diff_x + 1e-10)) + return mag, angle + + +def grad_mask(mask): + return mask[..., 1:, 1:] & mask[..., 1:, :-1] & mask[..., :-1, 1:] + + +class GradL1Loss(nn.Module): + """Gradient loss""" + def __init__(self): + super(GradL1Loss, self).__init__() + self.name = 'GradL1' + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + grad_gt = grad(target) + grad_pred = grad(input) + mask_g = grad_mask(mask) + + loss = nn.functional.l1_loss(grad_pred[0][mask_g], grad_gt[0][mask_g]) + loss = loss + \ + nn.functional.l1_loss(grad_pred[1][mask_g], grad_gt[1][mask_g]) + if not return_interpolated: + return loss + return loss, intr_input + + +class OrdinalRegressionLoss(object): + + def __init__(self, ord_num, beta, discretization="SID"): + self.ord_num = ord_num + self.beta = beta + self.discretization = discretization + + def _create_ord_label(self, gt): + N,one, H, W = gt.shape + # print("gt shape:", gt.shape) + + ord_c0 = torch.ones(N, self.ord_num, H, W).to(gt.device) + if self.discretization == "SID": + label = self.ord_num * torch.log(gt) / np.log(self.beta) + else: + label = self.ord_num * (gt - 1.0) / (self.beta - 1.0) + label = label.long() + mask = torch.linspace(0, self.ord_num - 1, self.ord_num, requires_grad=False) \ + .view(1, self.ord_num, 1, 1).to(gt.device) + mask = mask.repeat(N, 1, H, W).contiguous().long() + mask = (mask > label) + ord_c0[mask] = 0 + ord_c1 = 1 - ord_c0 + # implementation according to the paper. 
+ # ord_label = torch.ones(N, self.ord_num * 2, H, W).to(gt.device) + # ord_label[:, 0::2, :, :] = ord_c0 + # ord_label[:, 1::2, :, :] = ord_c1 + # reimplementation for fast speed. + ord_label = torch.cat((ord_c0, ord_c1), dim=1) + return ord_label, mask + + def __call__(self, prob, gt): + """ + :param prob: ordinal regression probability, N x 2*Ord Num x H x W, torch.Tensor + :param gt: depth ground truth, NXHxW, torch.Tensor + :return: loss: loss value, torch.float + """ + # N, C, H, W = prob.shape + valid_mask = gt > 0. + ord_label, mask = self._create_ord_label(gt) + # print("prob shape: {}, ord label shape: {}".format(prob.shape, ord_label.shape)) + entropy = -prob * ord_label + loss = torch.sum(entropy, dim=1)[valid_mask.squeeze(1)] + return loss.mean() + + +class DiscreteNLLLoss(nn.Module): + """Cross entropy loss""" + def __init__(self, min_depth=1e-3, max_depth=10, depth_bins=64): + super(DiscreteNLLLoss, self).__init__() + self.name = 'CrossEntropy' + self.ignore_index = -(depth_bins + 1) + # self._loss_func = nn.NLLLoss(ignore_index=self.ignore_index) + self._loss_func = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + self.min_depth = min_depth + self.max_depth = max_depth + self.depth_bins = depth_bins + self.alpha = 1 + self.zeta = 1 - min_depth + self.beta = max_depth + self.zeta + + def quantize_depth(self, depth): + # depth : N1HW + # output : NCHW + + # Quantize depth log-uniformly on [1, self.beta] into self.depth_bins bins + depth = torch.log(depth / self.alpha) / np.log(self.beta / self.alpha) + depth = depth * (self.depth_bins - 1) + depth = torch.round(depth) + depth = depth.long() + return depth + + + + def _dequantize_depth(self, depth): + """ + Inverse of quantization + depth : NCHW -> N1HW + """ + # Get the center of the bin + + + + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + # assert torch.all(input <= 0), "Input should be negative" + + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + # assert torch.all(input)<=1) + if target.ndim == 3: + target = target.unsqueeze(1) + + target = self.quantize_depth(target) + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + # Set the mask to ignore_index + mask = mask.long() + input = input * mask + (1 - mask) * self.ignore_index + target = target * mask + (1 - mask) * self.ignore_index + + + + input = input.flatten(2) # N, nbins, H*W + target = target.flatten(1) # N, H*W + loss = self._loss_func(input, target) + + if not return_interpolated: + return loss + return loss, intr_input + + + + +def compute_scale_and_shift(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + a_00 = torch.sum(mask * prediction * prediction, (1, 2)) + a_01 = torch.sum(mask * prediction, (1, 2)) + a_11 = torch.sum(mask, (1, 2)) + + # right hand side: b = [b_0, b_1] + b_0 = torch.sum(mask * prediction * target, (1, 2)) + b_1 = torch.sum(mask * target, (1, 2)) + + # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b + x_0 = torch.zeros_like(b_0) + x_1 = torch.zeros_like(b_1) + + det = a_00 * a_11 - a_01 * a_01 + # A needs to be a positive definite matrix. 
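# Editor's note: a_00, a_01, a_11, b_0 and b_1 above are the normal equations of the
# per-image weighted least-squares problem min over (s, t) of sum_i m_i * (s * p_i + t - d_i)^2,
# i.e. A @ [s, t]^T = b with A = [[sum m*p*p, sum m*p], [sum m*p, sum m]] and
# b = [sum m*p*d, sum m*d]. det > 0 together with the non-negative diagonal means A is
# positive definite, so the explicit 2x2 inverse below yields a unique scale and shift.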
+ valid = det > 0 + + x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] + x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] + + return x_0, x_1 +class ScaleAndShiftInvariantLoss(nn.Module): + def __init__(self): + super().__init__() + self.name = "SSILoss" + + def forward(self, prediction, target, mask, interpolate=True, return_interpolated=False): + + if prediction.shape[-1] != target.shape[-1] and interpolate: + prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = prediction + else: + intr_input = prediction + + + prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze() + assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}." + + scale, shift = compute_scale_and_shift(prediction, target, mask) + + scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) + + loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask]) + if not return_interpolated: + return loss + return loss, intr_input + + + + +if __name__ == '__main__': + # Tests for DiscreteNLLLoss + celoss = DiscreteNLLLoss() + print(celoss(torch.rand(4, 64, 26, 32)*10, torch.rand(4, 1, 26, 32)*10, )) + + d = torch.Tensor([6.59, 3.8, 10.0]) + print(celoss.dequantize_depth(celoss.quantize_depth(d))) diff --git a/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d528ae126f1c51b2f25fd31f94a39591ceb2f43a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py @@ -0,0 +1,143 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
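Before the trainer definitions continue, here is a hedged smoke test for the loss module above, in the spirit of its __main__ block. It assumes the zoedepth package added by this patch is importable; the shapes and values are arbitrary.

import torch

from zoedepth.trainers.loss import SILogLoss, ScaleAndShiftInvariantLoss

pred = torch.rand(2, 1, 24, 32) * 10 + 0.1   # predicted metric depth, N x 1 x H x W
gt = torch.rand(2, 1, 24, 32) * 10 + 0.1     # ground-truth depth, same layout
mask = gt > 0.5                              # boolean validity mask

silog = SILogLoss()
print("SILog:", silog(pred, gt, mask=mask).item())

# ScaleAndShiftInvariantLoss squeezes its inputs, aligns each prediction to the target
# with the closed-form scale and shift, then takes an L1 over the masked pixels.
ssi = ScaleAndShiftInvariantLoss()
print("SSI:", ssi(pred, gt, mask.squeeze(1)).item())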
+ +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics + +from .base_trainer import BaseTrainer + + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.domain_classifier_loss = nn.CrossEntropyLoss() + + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + + Assumes all images in a batch are from the same dataset + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + # batch['dataset'] is a tensor strings all valued either 'nyu' or 'kitti'. labels nyu -> 0, kitti -> 1 + dataset = batch['dataset'][0] + # Convert to 0s or 1s + domain_labels = torch.Tensor([dataset == 'kitti' for _ in range( + images.size(0))]).to(torch.long).to(self.device) + + # m = self.model.module if self.config.multigpu else self.model + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + output = self.model(images) + pred_depths = output['metric_depth'] + domain_logits = output['domain_logits'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + if self.config.w_domain > 0: + l_domain = self.domain_classifier_loss( + domain_logits, domain_labels) + loss = loss + self.config.w_domain * l_domain + losses["DomainLoss"] = l_domain + else: + l_domain = torch.Tensor([0.]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and self.step > 1 and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + self.scaler.update() + self.optimizer.zero_grad(set_to_none=True) + + return losses + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(images)["metric_depth"] + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + mask = 
torch.logical_and( + depths_gt > self.config.min_depth, depths_gt < self.config.max_depth) + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac1c24c0512c1c1b191670a7c24abb4fca47ba1 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py @@ -0,0 +1,177 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
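To connect the builder with the trainers defined here, the sketch below follows the entry path described in get_trainer's docstring. It assumes the model-specific JSON config (not shown in this excerpt) defines config.trainer and version_name, and that the model and dataloaders are built elsewhere; the commented lines are placeholders rather than code from this patch.

from zoedepth.trainers.builder import get_trainer
from zoedepth.utils.config import get_config

config = get_config("zoedepth_nk", mode="train", dataset="mix")
TrainerCls = get_trainer(config)  # resolves zoedepth.trainers.{config.trainer}_trainer.Trainer
# trainer = TrainerCls(config, model, train_loader, test_loader=test_loader, device="cuda")
# trainer.train()  # assumed BaseTrainer entry point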
+ +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics +from zoedepth.data.preprocess import get_black_border + +from .base_trainer import BaseTrainer +from torchvision import transforms +from PIL import Image +import numpy as np + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + dataset = batch['dataset'][0] + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + + output = self.model(images) + pred_depths = output['metric_depth'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + # -99 is treated as invalid depth in the log_images function and is colored grey. 
+ depths_gt[torch.logical_not(mask)] = -99 + + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + if self.config.get("log_rel", False): + self.log_images( + scalar_field={"RelPred": output["relative_depth"][0]}, prefix="TrainRel") + + self.scaler.update() + self.optimizer.zero_grad() + + return losses + + @torch.no_grad() + def eval_infer(self, x): + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(x)['metric_depth'] + return pred_depths + + @torch.no_grad() + def crop_aware_infer(self, x): + # if we are not avoiding the black border, we can just use the normal inference + if not self.config.get("avoid_boundary", False): + return self.eval_infer(x) + + # otherwise, we need to crop the image to avoid the black border + # For now, this may be a bit slow due to converting to numpy and back + # We assume no normalization is done on the input image + + # get the black border + assert x.shape[0] == 1, "Only batch size 1 is supported for now" + x_pil = transforms.ToPILImage()(x[0].cpu()) + x_np = np.array(x_pil, dtype=np.uint8) + black_border_params = get_black_border(x_np) + top, bottom, left, right = black_border_params.top, black_border_params.bottom, black_border_params.left, black_border_params.right + x_np_cropped = x_np[top:bottom, left:right, :] + x_cropped = transforms.ToTensor()(Image.fromarray(x_np_cropped)) + + # run inference on the cropped image + pred_depths_cropped = self.eval_infer(x_cropped.unsqueeze(0).to(self.device)) + + # resize the prediction to x_np_cropped's size + pred_depths_cropped = nn.functional.interpolate( + pred_depths_cropped, size=(x_np_cropped.shape[0], x_np_cropped.shape[1]), mode="bilinear", align_corners=False) + + + # pad the prediction back to the original size + pred_depths = torch.zeros((1, 1, x_np.shape[0], x_np.shape[1]), device=pred_depths_cropped.device, dtype=pred_depths_cropped.dtype) + pred_depths[:, :, top:bottom, left:right] = pred_depths_cropped + + return pred_depths + + + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + mask = batch["mask"].to(self.device) + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + mask = mask.squeeze().unsqueeze(0).unsqueeze(0) + if dataset == 'nyu': + pred_depths = self.crop_aware_infer(images) + else: + pred_depths = self.eval_infer(images) + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/src/flux/annotator/zoe/zoedepth/utils/__init__.py b/src/flux/annotator/zoe/zoedepth/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py b/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py @@ -0,0 +1,33 @@ + + +def infer_type(x): # hacky way to infer type from string args + if not isinstance(x, str): + return x + + try: + x = int(x) + return x + except ValueError: + pass + + try: + x = float(x) + return x + except ValueError: + pass + + return x + + +def parse_unknown(unknown_args): + clean = [] + for a in unknown_args: + if "=" in a: + k, v = a.split("=") + clean.extend([k, v]) + else: + clean.append(a) + + keys = clean[::2] + values = clean[1::2] + return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} diff --git a/src/flux/annotator/zoe/zoedepth/utils/config.py b/src/flux/annotator/zoe/zoedepth/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..84996564663dadf0e720de2a68ef8c53106ed666 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/config.py @@ -0,0 +1,437 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import json +import os + +from .easydict import EasyDict as edict +from .arg_utils import infer_type + +import pathlib +import platform + +ROOT = pathlib.Path(__file__).parent.parent.resolve() + +HOME_DIR = os.path.expanduser("~") + +COMMON_CONFIG = { + "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), + "project": "ZoeDepth", + "tags": '', + "notes": "", + "gpu": None, + "root": ".", + "uid": None, + "print_losses": False +} + +DATASETS_CONFIG = { + "kitti": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, # 704 + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "kitti_test": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": False, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "nyu": { + "dataset": "nyu", + "avoid_boundary": False, + "min_depth": 1e-3, # originally 0.1 + "max_depth": 10, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", + "input_height": 480, + "input_width": 640, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth_diff": -10, + "max_depth_diff": 10, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": False, + "garg_crop": False, + "eigen_crop": True + }, + "ibims": { + "dataset": "ibims", + "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "sunrgbd": { + "dataset": "sunrgbd", + 
"sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 8, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_indoor": { + "dataset": "diml_indoor", + "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_outdoor": { + "dataset": "diml_outdoor", + "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 2, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "diode_indoor": { + "dataset": "diode_indoor", + "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diode_outdoor": { + "dataset": "diode_outdoor", + "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "hypersim_test": { + "dataset": "hypersim_test", + "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 10 + }, + "vkitti": { + "dataset": "vkitti", + "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "vkitti2": { + "dataset": "vkitti2", + "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, + "ddad": { + "dataset": "ddad", + "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, +} + +ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] +ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] +ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR + +COMMON_TRAINING_CONFIG = { + "dataset": "nyu", + "distributed": True, + "workers": 16, + "clip_grad": 0.1, + "use_shared_dict": False, + "shared_dict": None, + "use_amp": False, + + "aug": True, + "random_crop": False, + "random_translate": False, + "translate_prob": 0.2, + "max_translation": 100, + + "validate_every": 0.25, + "log_images_every": 0.1, + "prefetch": False, +} + + +def flatten(config, except_keys=('bin_conf')): + def recurse(inp): + if isinstance(inp, dict): + for key, value in inp.items(): + if key in except_keys: + yield (key, value) + if isinstance(value, dict): + yield from recurse(value) + else: + yield (key, value) + + return dict(list(recurse(config))) + + +def split_combined_args(kwargs): + """Splits the arguments that are combined with '__' into multiple arguments. 
+ Combined arguments should have equal number of keys and values. + Keys are separated by '__' and Values are separated with ';'. + For example, '__n_bins__lr=256;0.001' + + Args: + kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. + + Returns: + dict: Parsed dict with the combined arguments split into individual key-value pairs. + """ + new_kwargs = dict(kwargs) + for key, value in kwargs.items(): + if key.startswith("__"): + keys = key.split("__")[1:] + values = value.split(";") + assert len(keys) == len( + values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})" + for k, v in zip(keys, values): + new_kwargs[k] = v + return new_kwargs + + +def parse_list(config, key, dtype=int): + """Parse a list of values for the key if the value is a string. The values are separated by a comma. + Modifies the config in place. + """ + if key in config: + if isinstance(config[key], str): + config[key] = list(map(dtype, config[key].split(','))) + assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]] + ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." + + +def get_model_config(model_name, model_version=None): + """Find and parse the .json config file for the model. + + Args: + model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. + model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. + + Returns: + easydict: the config dictionary for the model. + """ + config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" + config_file = os.path.join(ROOT, "models", model_name, config_fname) + if not os.path.exists(config_file): + return None + + with open(config_file, "r") as f: + config = edict(json.load(f)) + + # handle dictionary inheritance + # only training config is supported for inheritance + if "inherit" in config.train and config.train.inherit is not None: + inherit_config = get_model_config(config.train["inherit"]).train + for key, value in inherit_config.items(): + if key not in config.train: + config.train[key] = value + return edict(config) + + +def update_model_config(config, mode, model_name, model_version=None, strict=False): + model_config = get_model_config(model_name, model_version) + if model_config is not None: + config = {**config, ** + flatten({**model_config.model, **model_config[mode]})} + elif strict: + raise ValueError(f"Config file for model {model_name} not found.") + return config + + +def check_choices(name, value, choices): + # return # No checks in dev branch + if value not in choices: + raise ValueError(f"{name} {value} not in supported choices {choices}") + + +KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", + "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 + + +def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): + """Main entry point to get the config for the model. 
+ + Args: + model_name (str): name of the desired model. + mode (str, optional): "train" or "infer". Defaults to 'train'. + dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. + + Keyword Args: key-value pairs of arguments to overwrite the default config. + + The order of precedence for overwriting the config is (Higher precedence first): + # 1. overwrite_kwargs + # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json + # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json + # 4. common_config: Default config for all models specified in COMMON_CONFIG + + Returns: + easydict: The config dictionary for the model. + """ + + + check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) + check_choices("Mode", mode, ["train", "infer", "eval"]) + if mode == "train": + check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) + + config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) + config = update_model_config(config, mode, model_name) + + # update with model version specific config + version_name = overwrite_kwargs.get("version_name", config["version_name"]) + config = update_model_config(config, mode, model_name, version_name) + + # update with config version if specified + config_version = overwrite_kwargs.get("config_version", None) + if config_version is not None: + print("Overwriting config with config_version", config_version) + config = update_model_config(config, mode, model_name, config_version) + + # update with overwrite_kwargs + # Combined args are useful for hyperparameter search + overwrite_kwargs = split_combined_args(overwrite_kwargs) + config = {**config, **overwrite_kwargs} + + # Casting to bool # TODO: Not necessary. Remove and test + for key in KEYS_TYPE_BOOL: + if key in config: + config[key] = bool(config[key]) + + # Model specific post processing of config + parse_list(config, "n_attractors") + + # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs + if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: + bin_conf = config['bin_conf'] # list of dicts + n_bins = overwrite_kwargs['n_bins'] + new_bin_conf = [] + for conf in bin_conf: + conf['n_bins'] = n_bins + new_bin_conf.append(conf) + config['bin_conf'] = new_bin_conf + + if mode == "train": + orig_dataset = dataset + if dataset == "mix": + dataset = 'nyu' # Use nyu as default for mix. 
Dataset config is changed accordingly while loading the dataloader + if dataset is not None: + config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb + + if dataset is not None: + config['dataset'] = dataset + config = {**DATASETS_CONFIG[dataset], **config} + + + config['model'] = model_name + typed_config = {k: infer_type(v) for k, v in config.items()} + # add hostname to config + config['hostname'] = platform.node() + return edict(typed_config) + + +def change_dataset(config, new_dataset): + config.update(DATASETS_CONFIG[new_dataset]) + return config diff --git a/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py b/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py @@ -0,0 +1,158 @@ +""" +EasyDict +Copy/pasted from https://github.com/makinacorpus/easydict +Original author: Mathieu Leplatre +""" + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + >>> EasyDict((('a', 1), ('b', 2))) + {'a': 1, 'b': 2} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> list(map(attrgetter('x'), d.bar)) + [1, 3] + >>> list(map(attrgetter('y'), d.bar)) + [2, 4] + >>> d = EasyDict() + >>> list(d.keys()) + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> list(o.items()) + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... 
+ AttributeError: 'EasyDict' object has no attribute 'a' + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + else: + d = dict(d) + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + doctest.testmod() \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/utils/geometry.py b/src/flux/annotator/zoe/zoedepth/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e3da8c75b5a8e39b4b58a4dcd827b84d79b9115c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/geometry.py @@ -0,0 +1,98 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np + +def get_intrinsics(H,W): + """ + Intrinsics for a pinhole camera model. + Assume fov of 55 degrees and central principal point. 
+ """ + f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0) + cx = 0.5 * W + cy = 0.5 * H + return np.array([[f, 0, cx], + [0, f, cy], + [0, 0, 1]]) + +def depth_to_points(depth, R=None, t=None): + + K = get_intrinsics(depth.shape[1], depth.shape[2]) + Kinv = np.linalg.inv(K) + if R is None: + R = np.eye(3) + if t is None: + t = np.zeros(3) + + # M converts from your coordinate to PyTorch3D's coordinate system + M = np.eye(3) + M[0, 0] = -1.0 + M[1, 1] = -1.0 + + height, width = depth.shape[1:3] + + x = np.arange(width) + y = np.arange(height) + coord = np.stack(np.meshgrid(x, y), -1) + coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) # z=1 + coord = coord.astype(np.float32) + # coord = torch.as_tensor(coord, dtype=torch.float32, device=device) + coord = coord[None] # bs, h, w, 3 + + D = depth[:, :, :, None, None] + # print(D.shape, Kinv[None, None, None, ...].shape, coord[:, :, :, :, None].shape ) + pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None] + # pts3D_1 live in your coordinate system. Convert them to Py3D's + pts3D_1 = M[None, None, None, ...] @ pts3D_1 + # from reference to targe tviewpoint + pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None] + # pts3D_2 = pts3D_1 + # depth_2 = pts3D_2[:, :, :, 2, :] # b,1,h,w + return pts3D_2[:, :, :, :3, 0][0] + + +def create_triangles(h, w, mask=None): + """ + Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68 + Creates mesh triangle indices from a given pixel grid size. + This function is not and need not be differentiable as triangle indices are + fixed. + Args: + h: (int) denoting the height of the image. + w: (int) denoting the width of the image. + Returns: + triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3) + """ + x, y = np.meshgrid(range(w - 1), range(h - 1)) + tl = y * w + x + tr = y * w + x + 1 + bl = (y + 1) * w + x + br = (y + 1) * w + x + 1 + triangles = np.array([tl, bl, tr, br, tr, bl]) + triangles = np.transpose(triangles, (1, 2, 0)).reshape( + ((w - 1) * (h - 1) * 2, 3)) + if mask is not None: + mask = mask.reshape(-1) + triangles = triangles[mask[triangles].all(1)] + return triangles diff --git a/src/flux/annotator/zoe/zoedepth/utils/misc.py b/src/flux/annotator/zoe/zoedepth/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbe403d3669829eecdf658458c76aa5e87e2b33 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/misc.py @@ -0,0 +1,368 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +"""Miscellaneous utility functions.""" + +from scipy import ndimage + +import base64 +import math +import re +from io import BytesIO + +import matplotlib +import matplotlib.cm +import numpy as np +import requests +import torch +import torch.distributed as dist +import torch.nn +import torch.nn as nn +import torch.utils.data.distributed +from PIL import Image +from torchvision.transforms import ToTensor + + +class RunningAverage: + def __init__(self): + self.avg = 0 + self.count = 0 + + def append(self, value): + self.avg = (value + self.count * self.avg) / (self.count + 1) + self.count += 1 + + def get_value(self): + return self.avg + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + + +class RunningAverageDict: + """A dictionary of running averages.""" + def __init__(self): + self._dict = None + + def update(self, new_dict): + if new_dict is None: + return + + if self._dict is None: + self._dict = dict() + for key, value in new_dict.items(): + self._dict[key] = RunningAverage() + + for key, value in new_dict.items(): + self._dict[key].append(value) + + def get_value(self): + if self._dict is None: + return None + return {key: value.get_value() for key, value in self._dict.items()} + + +def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None): + """Converts a depth map to a color image. + + Args: + value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed + vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None. + vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None. + cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'. + invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99. + invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None. + background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255). + gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False. + value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None. + + Returns: + numpy.ndarray, dtype - uint8: Colored depth map. 
Shape: (H, W, 4) + """ + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() + + value = value.squeeze() + if invalid_mask is None: + invalid_mask = value == invalid_val + mask = np.logical_not(invalid_mask) + + # normalize + vmin = np.percentile(value[mask],2) if vmin is None else vmin + vmax = np.percentile(value[mask],85) if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0. + + # squeeze last dim if it exists + # grey out the invalid values + + value[invalid_mask] = np.nan + cmapper = matplotlib.cm.get_cmap(cmap) + if value_transform: + value = value_transform(value) + # value = value / value.max() + value = cmapper(value, bytes=True) # (nxmx4) + + # img = value[:, :, :] + img = value[...] + img[invalid_mask] = background_color + + # return img.transpose((2, 0, 1)) + if gamma_corrected: + # gamma correction + img = img / 255 + img = np.power(img, 2.2) + img = img * 255 + img = img.astype(np.uint8) + return img + + +def count_parameters(model, include_all=False): + return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all) + + +def compute_errors(gt, pred): + """Compute metrics for 'pred' compared to 'gt' + + Args: + gt (numpy.ndarray): Ground truth values + pred (numpy.ndarray): Predicted values + + gt.shape should be equal to pred.shape + + Returns: + dict: Dictionary containing the following metrics: + 'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25 + 'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2 + 'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3 + 'abs_rel': Absolute relative error + 'rmse': Root mean squared error + 'log_10': Absolute log10 error + 'sq_rel': Squared relative error + 'rmse_log': Root mean squared error on the log scale + 'silog': Scale invariant log error + """ + thresh = np.maximum((gt / pred), (pred / gt)) + a1 = (thresh < 1.25).mean() + a2 = (thresh < 1.25 ** 2).mean() + a3 = (thresh < 1.25 ** 3).mean() + + abs_rel = np.mean(np.abs(gt - pred) / gt) + sq_rel = np.mean(((gt - pred) ** 2) / gt) + + rmse = (gt - pred) ** 2 + rmse = np.sqrt(rmse.mean()) + + rmse_log = (np.log(gt) - np.log(pred)) ** 2 + rmse_log = np.sqrt(rmse_log.mean()) + + err = np.log(pred) - np.log(gt) + silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 + + log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() + return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log, + silog=silog, sq_rel=sq_rel) + + +def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs): + """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics. 
+ """ + if 'config' in kwargs: + config = kwargs['config'] + garg_crop = config.garg_crop + eigen_crop = config.eigen_crop + min_depth_eval = config.min_depth_eval + max_depth_eval = config.max_depth_eval + + if gt.shape[-2:] != pred.shape[-2:] and interpolate: + pred = nn.functional.interpolate( + pred, gt.shape[-2:], mode='bilinear', align_corners=True) + + pred = pred.squeeze().cpu().numpy() + pred[pred < min_depth_eval] = min_depth_eval + pred[pred > max_depth_eval] = max_depth_eval + pred[np.isinf(pred)] = max_depth_eval + pred[np.isnan(pred)] = min_depth_eval + + gt_depth = gt.squeeze().cpu().numpy() + valid_mask = np.logical_and( + gt_depth > min_depth_eval, gt_depth < max_depth_eval) + + if garg_crop or eigen_crop: + gt_height, gt_width = gt_depth.shape + eval_mask = np.zeros(valid_mask.shape) + + if garg_crop: + eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), + int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1 + + elif eigen_crop: + # print("-"*10, " EIGEN CROP ", "-"*10) + if dataset == 'kitti': + eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), + int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1 + else: + # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images" + eval_mask[45:471, 41:601] = 1 + else: + eval_mask = np.ones(valid_mask.shape) + valid_mask = np.logical_and(valid_mask, eval_mask) + return compute_errors(gt_depth[valid_mask], pred[valid_mask]) + + +#################################### Model uilts ################################################ + + +def parallelize(config, model, find_unused_parameters=True): + + if config.gpu is not None: + torch.cuda.set_device(config.gpu) + model = model.cuda(config.gpu) + + config.multigpu = False + if config.distributed: + # Use DDP + config.multigpu = True + config.rank = config.rank * config.ngpus_per_node + config.gpu + dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, + world_size=config.world_size, rank=config.rank) + config.batch_size = int(config.batch_size / config.ngpus_per_node) + # config.batch_size = 8 + config.workers = int( + (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node) + print("Device", config.gpu, "Rank", config.rank, "batch size", + config.batch_size, "Workers", config.workers) + torch.cuda.set_device(config.gpu) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = model.cuda(config.gpu) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu, + find_unused_parameters=find_unused_parameters) + + elif config.gpu is None: + # Use DP + config.multigpu = True + model = model.cuda() + model = torch.nn.DataParallel(model) + + return model + + +################################################################################################# + + +##################################################################################################### + + +class colors: + '''Colors class: + Reset all colors with colors.reset + Two subclasses fg for foreground and bg for background. + Use as colors.subclass.colorname. + i.e. colors.fg.red or colors.bg.green + Also, the generic bold, disable, underline, reverse, strikethrough, + and invisible work with the main class + i.e. 
colors.bold + ''' + reset = '\033[0m' + bold = '\033[01m' + disable = '\033[02m' + underline = '\033[04m' + reverse = '\033[07m' + strikethrough = '\033[09m' + invisible = '\033[08m' + + class fg: + black = '\033[30m' + red = '\033[31m' + green = '\033[32m' + orange = '\033[33m' + blue = '\033[34m' + purple = '\033[35m' + cyan = '\033[36m' + lightgrey = '\033[37m' + darkgrey = '\033[90m' + lightred = '\033[91m' + lightgreen = '\033[92m' + yellow = '\033[93m' + lightblue = '\033[94m' + pink = '\033[95m' + lightcyan = '\033[96m' + + class bg: + black = '\033[40m' + red = '\033[41m' + green = '\033[42m' + orange = '\033[43m' + blue = '\033[44m' + purple = '\033[45m' + cyan = '\033[46m' + lightgrey = '\033[47m' + + +def printc(text, color): + print(f"{color}{text}{colors.reset}") + +############################################ + +def get_image_from_url(url): + response = requests.get(url) + img = Image.open(BytesIO(response.content)).convert("RGB") + return img + +def url_to_torch(url, size=(384, 384)): + img = get_image_from_url(url) + img = img.resize(size, Image.ANTIALIAS) + img = torch.from_numpy(np.asarray(img)).float() + img = img.permute(2, 0, 1) + img.div_(255) + return img + +def pil_to_batched_tensor(img): + return ToTensor()(img).unsqueeze(0) + +def save_raw_16bit(depth, fpath="raw.png"): + if isinstance(depth, torch.Tensor): + depth = depth.squeeze().cpu().numpy() + + assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array" + assert depth.ndim == 2, "Depth must be 2D" + depth = depth * 256 # scale for 16-bit png + depth = depth.astype(np.uint16) + depth = Image.fromarray(depth) + depth.save(fpath) + print("Saved raw depth to", fpath) \ No newline at end of file diff --git a/src/flux/api.py b/src/flux/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b08202adb35d2ffae320bb9b47f567e538837836 --- /dev/null +++ b/src/flux/api.py @@ -0,0 +1,194 @@ +import io +import os +import time +from pathlib import Path + +import requests +from PIL import Image + +API_ENDPOINT = "https://api.bfl.ml" + + +class ApiException(Exception): + def __init__(self, status_code: int, detail: str | list[dict] | None = None): + super().__init__() + self.detail = detail + self.status_code = status_code + + def __str__(self) -> str: + return self.__repr__() + + def __repr__(self) -> str: + if self.detail is None: + message = None + elif isinstance(self.detail, str): + message = self.detail + else: + message = "[" + ",".join(d["msg"] for d in self.detail) + "]" + return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" + + +class ImageRequest: + def __init__( + self, + prompt: str, + width: int = 1024, + height: int = 1024, + name: str = "flux.1-pro", + num_steps: int = 50, + prompt_upsampling: bool = False, + seed: int | None = None, + validate: bool = True, + launch: bool = True, + api_key: str | None = None, + ): + """ + Manages an image generation request to the API. 
+ + Args: + prompt: Prompt to sample + width: Width of the image in pixel + height: Height of the image in pixel + name: Name of the model + num_steps: Number of network evaluations + prompt_upsampling: Use prompt upsampling + seed: Fix the generation seed + validate: Run input validation + launch: Directly launches request + api_key: Your API key if not provided by the environment + + Raises: + ValueError: For invalid input + ApiException: For errors raised from the API + """ + if validate: + if name not in ["flux.1-pro"]: + raise ValueError(f"Invalid model {name}") + elif width % 32 != 0: + raise ValueError(f"width must be divisible by 32, got {width}") + elif not (256 <= width <= 1440): + raise ValueError(f"width must be between 256 and 1440, got {width}") + elif height % 32 != 0: + raise ValueError(f"height must be divisible by 32, got {height}") + elif not (256 <= height <= 1440): + raise ValueError(f"height must be between 256 and 1440, got {height}") + elif not (1 <= num_steps <= 50): + raise ValueError(f"steps must be between 1 and 50, got {num_steps}") + + self.request_json = { + "prompt": prompt, + "width": width, + "height": height, + "variant": name, + "steps": num_steps, + "prompt_upsampling": prompt_upsampling, + } + if seed is not None: + self.request_json["seed"] = seed + + self.request_id: str | None = None + self.result: dict | None = None + self._image_bytes: bytes | None = None + self._url: str | None = None + if api_key is None: + self.api_key = os.environ.get("BFL_API_KEY") + else: + self.api_key = api_key + + if launch: + self.request() + + def request(self): + """ + Request to generate the image. + """ + if self.request_id is not None: + return + response = requests.post( + f"{API_ENDPOINT}/v1/image", + headers={ + "accept": "application/json", + "x-key": self.api_key, + "Content-Type": "application/json", + }, + json=self.request_json, + ) + result = response.json() + if response.status_code != 200: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + self.request_id = response.json()["id"] + + def retrieve(self) -> dict: + """ + Wait for the generation to finish and retrieve response. + """ + if self.request_id is None: + self.request() + while self.result is None: + response = requests.get( + f"{API_ENDPOINT}/v1/get_result", + headers={ + "accept": "application/json", + "x-key": self.api_key, + }, + params={ + "id": self.request_id, + }, + ) + result = response.json() + if "status" not in result: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + elif result["status"] == "Ready": + self.result = result["result"] + elif result["status"] == "Pending": + time.sleep(0.5) + else: + raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") + return self.result + + @property + def bytes(self) -> bytes: + """ + Generated image as bytes. 
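A hedged usage sketch for the `ImageRequest` client defined in this file: it assumes the package is importable as `flux.api`, that `BFL_API_KEY` is set in the environment, and that the endpoint is reachable. The prompt and sizes below are placeholders.

```python
# Minimal client usage sketch for the ImageRequest class above.
from flux.api import ApiException, ImageRequest

try:
    request = ImageRequest(
        prompt="a watercolor fox in a snowy forest",
        width=1024,             # must be a multiple of 32 in [256, 1440]
        height=768,
        num_steps=40,           # must be in [1, 50]
        prompt_upsampling=False,
        launch=True,            # POSTs /v1/image immediately
    )
    request.save("fox")         # blocks in retrieve() until the status is "Ready"
except ApiException as exc:
    print(f"generation failed: {exc}")
```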
+ """ + if self._image_bytes is None: + response = requests.get(self.url) + if response.status_code == 200: + self._image_bytes = response.content + else: + raise ApiException(status_code=response.status_code) + return self._image_bytes + + @property + def url(self) -> str: + """ + Public url to retrieve the image from + """ + if self._url is None: + result = self.retrieve() + self._url = result["sample"] + return self._url + + @property + def image(self) -> Image.Image: + """ + Load the image as a PIL Image + """ + return Image.open(io.BytesIO(self.bytes)) + + def save(self, path: str): + """ + Save the generated image to a local path + """ + suffix = Path(self.url).suffix + if not path.endswith(suffix): + path = path + suffix + Path(path).resolve().parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as file: + file.write(self.bytes) + + +if __name__ == "__main__": + from fire import Fire + + Fire(ImageRequest) diff --git a/src/flux/cli.py b/src/flux/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..f3624bc6c387f359162e68f46995b12ce341970a --- /dev/null +++ b/src/flux/cli.py @@ -0,0 +1,254 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from einops import rearrange +from fire import Fire +from PIL import ExifTags, Image + +from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from flux.util import (configs, embed_watermark, load_ae, load_clip, + load_flow_model, load_t5) +from transformers import pipeline + +NSFW_THRESHOLD = 0.85 + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/w ' will set the width of the generated image\n" + "- '/h ' will set the height of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command 
'{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting seed to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +@torch.inference_mode() +def main( + name: str = "flux-schnell", + width: int = 1360, + height: int = 768, + seed: int | None = None, + prompt: str = ( + "a photo of a forest with mist swirling around the tree trunks. The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ), + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int | None = None, + loop: bool = False, + guidance: float = 3.5, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + name: Name of the model to load + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection") + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 4 if name == "flux-schnell" else 50 + + # allow for packing and conversion to latent space + height = 16 * (height // 16) + width = 16 * (width // 16) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + ) + + if loop: + opts = parse_prompt(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) + inp = prepare(t5, clip, x, prompt=opts.prompt) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name 
!= "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + t1 = time.perf_counter() + + fn = output_name.format(idx=idx) + print(f"Done in {t1 - t0:.1f}s. Saving {fn}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] + + if nsfw_score < NSFW_THRESHOLD: + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(fn, exif=exif_data, quality=95, subsampling=0) + idx += 1 + else: + print("Your generated image may contain NSFW content.") + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + else: + opts = None + + +def app(): + Fire(main) + + +if __name__ == "__main__": + app() diff --git a/src/flux/controlnet.py b/src/flux/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a04cc0234b2b726a550cbe62d027943f6bbcbb --- /dev/null +++ b/src/flux/controlnet.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. 
+ """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_res_samples = () + + for block in self.double_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + block_res_samples = block_res_samples + (img,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + return controlnet_block_res_samples diff --git a/src/flux/math.py b/src/flux/math.py new file mode 100644 index 0000000000000000000000000000000000000000..0156bb6a205dec340e029f0c87cf70ae8709ae12 --- /dev/null +++ b/src/flux/math.py @@ -0,0 +1,30 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = 
torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/src/flux/model.py b/src/flux/model.py new file mode 100644 index 0000000000000000000000000000000000000000..51531c114babcea3b7a365ca44ee458bfce9a673 --- /dev/null +++ b/src/flux/model.py @@ -0,0 +1,228 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + 
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + guidance: Tensor | None = None, + image_proj: Tensor | None = None, + ip_scale: Tensor | float = 1.0, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + for index_block, block in enumerate(self.double_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + image_proj, + ip_scale, + ) + else: + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + image_proj=image_proj, + ip_scale=ip_scale, + ) + # controlnet residual + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[index_block % 2] + + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, 
return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + vec, + pe, + ) + else: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/src/flux/modules/autoencoder.py b/src/flux/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..75159f711f65f064107a1a1b9be6f09fc9872028 --- /dev/null +++ b/src/flux/modules/autoencoder.py @@ -0,0 +1,312 @@ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = 
nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for 
i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/src/flux/modules/conditioner.py b/src/flux/modules/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..7cdd881878ace848745da7d723c60f03392916ab --- /dev/null +++ b/src/flux/modules/conditioner.py @@ -0,0 +1,38 @@ +from torch import Tensor, nn +from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, + T5Tokenizer) + + +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): + super().__init__() + self.is_clip = version.startswith("openai") + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) + else: + self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) + 
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, text: list[str]) -> Tensor: + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] diff --git a/src/flux/modules/layers.py b/src/flux/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..c5489698671c6ed32dcb790a2f83d682d898b872 --- /dev/null +++ b/src/flux/modules/layers.py @@ -0,0 +1,595 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from ..math import attention, rope +import torch.nn.functional as F + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
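`EmbedND` above builds one rotary table per position axis (for example a text index plus image row and column) and concatenates them along the feature dimension. An illustrative sketch, assuming the package is importable as `flux.modules.layers`; the `axes_dim` and `theta` values are placeholders chosen only so that they sum to the assumed per-head dimension.

```python
# Illustrative use of EmbedND with three position axes.
import torch
from flux.modules.layers import EmbedND

pe_dim = 128                                 # hidden_size // num_heads (assumed)
embedder = EmbedND(dim=pe_dim, theta=10_000, axes_dim=[16, 56, 56])

ids = torch.zeros(1, 64, 3)                  # 64 tokens, 3 position axes
ids[..., 1] = torch.arange(64) // 8          # image row
ids[..., 2] = torch.arange(64) % 8           # image column
pe = embedder(ids)
print(pe.shape)                              # torch.Size([1, 1, 64, 64, 2, 2])
```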
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + t.device + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + +class FLuxSelfAttnProcessor: + def __call__(self, attn, x, pe, **attention_kwargs): + print('2' * 30) + + qkv = attn.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + return x + +class LoraFluxAttnProcessor(nn.Module): + + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + + def __call__(self, attn, x, pe, **attention_kwargs): + qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + self.proj_lora(x) * self.lora_weight + print('1' * 30) + print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm') + return x + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + def forward(): + pass + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + +class DoubleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for 
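The processors in this file all follow the same adaLN-style pattern: `Modulation` projects the conditioning vector into shift/scale/gate triples, the normalised stream is affinely modulated, and the gate weights the residual that is added back. A minimal sketch of that pattern, assuming the package is importable as `flux.modules.layers` and using 3072 as a stand-in hidden size.

```python
# The modulation pattern shared by the double- and single-stream blocks.
import torch
from torch import nn
from flux.modules.layers import Modulation

hidden = 3072
mod = Modulation(hidden, double=True)
norm = nn.LayerNorm(hidden, elementwise_affine=False, eps=1e-6)

vec = torch.randn(1, hidden)                     # pooled conditioning vector
x = torch.randn(1, 64, hidden)                   # a stream of 64 tokens

mod1, mod2 = mod(vec)                            # two triples: attention path and MLP path
x_mod = (1 + mod1.scale) * norm(x) + mod1.shift  # what the attention branch consumes
x = x + mod1.gate * x_mod                        # gated residual (the attention output in the real block)
print(x.shape)                                   # torch.Size([1, 64, 3072])
```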
attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class IPDoubleStreamBlockProcessor(nn.Module): + """Attention processor for handling IP-adapter with double stream block.""" + + def __init__(self, context_dim, hidden_dim): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." + ) + + # Ensure context_dim matches the dimension of image_proj + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True) + self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True) + + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias) + + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias) + + def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs): + + # Prepare image for attention + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, :txt.shape[1]], attn1[:, txt.shape[1]:] + + # print(f"txt_attn shape: {txt_attn.size()}") + # print(f"img_attn shape: {img_attn.size()}") + + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * 
attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + + + # IP-adapter processing + ip_query = img_q # latent sample query + ip_key = self.ip_adapter_double_stream_k_proj(image_proj) + ip_value = self.ip_adapter_double_stream_v_proj(image_proj) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value, + dropout_p=0.0, + is_causal=False + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim) + + img = img + ip_scale * ip_attention + + return img, txt + +class DoubleStreamBlockProcessor: + def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_dim = hidden_size // num_heads + + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + 
nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + processor = DoubleStreamBlockProcessor() + self.set_processor(processor) + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor = None, + ip_scale: float =1.0, + ) -> tuple[Tensor, Tensor]: + if image_proj is None: + return self.processor(self, img, txt, vec, pe) + else: + return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) + +class IPSingleStreamBlockProcessor(nn.Module): + """Attention processor for handling IP-adapter with single stream block.""" + def __init__(self, context_dim, hidden_dim): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." + ) + + # Ensure context_dim matches the dimension of image_proj + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False) + self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False) + + nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight) + + def __call__( + self, + attn: nn.Module, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0 + ) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # IP-adapter processing + ip_query = q + ip_key = self.ip_adapter_single_stream_k_proj(image_proj) + ip_value = self.ip_adapter_single_stream_v_proj(image_proj) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + + + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)") + + attn_out = attn_1 + ip_scale * ip_attention + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2)) + out = x + mod.gate * output + + return out + + +class SingleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + qkv = qkv + self.qkv_lora(x_mod) * 
self.lora_weight + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight + output = x + mod.gate * output + return output + + +class SingleStreamBlockProcessor: + def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = x + mod.gate * output + return output + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(self.head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + processor = SingleStreamBlockProcessor() + self.set_processor(processor) + + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0 + ) -> Tensor: + if image_proj is None: + return self.processor(self, x, vec, pe) + else: + return self.processor(self, x, vec, pe, image_proj, ip_scale) + + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + +class ImageProjModel(torch.nn.Module): + """Projection Model + https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28 + """ + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + 
self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + diff --git a/src/flux/sampling.py b/src/flux/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a97f7d8e5a2586f8c9849f0d6470af9c201041 --- /dev/null +++ b/src/flux/sampling.py @@ -0,0 +1,242 @@ +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from .model import Flux +from .modules.conditioner import HFEmbedder + + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + timestep_to_start_cfg=0, + # ip-adapter parameters + image_proj: Tensor=None, + neg_image_proj: Tensor=None, + ip_scale: Tensor | float = 1.0, + neg_ip_scale: Tensor | float = 1.0 +): + i = 0 + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + image_proj=image_proj, + ip_scale=ip_scale, + ) + if i >= timestep_to_start_cfg: + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + image_proj=neg_image_proj, + ip_scale=neg_ip_scale, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + img = img + (t_prev - t_curr) * pred + i += 1 + return img + +def denoise_controlnet( + model: Flux, + controlnet:None, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + controlnet_cond, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + controlnet_gs=0.7, + timestep_to_start_cfg=0, + # ip-adapter parameters + image_proj: Tensor=None, + neg_image_proj: Tensor=None, + ip_scale: Tensor | float = 1, + neg_ip_scale: Tensor | float = 1, +): + # this is ignored for schnell + i = 0 + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples], + image_proj=image_proj, + ip_scale=ip_scale, + ) + 
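+        # Classifier-free guidance: from step `timestep_to_start_cfg` onwards a
+        # negative-prompt pass is run through the controlnet and the model, and the
+        # two predictions are blended with the `true_gs` scale below.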
if i >= timestep_to_start_cfg: + neg_block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples], + image_proj=neg_image_proj, + ip_scale=neg_ip_scale, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + + img = img + (t_prev - t_curr) * pred + + i += 1 + return img + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/src/flux/util.py b/src/flux/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd496213f0acc9cd05ddc621d504dd9e87373e9 --- /dev/null +++ b/src/flux/util.py @@ -0,0 +1,433 @@ +import os +from dataclasses import dataclass + +import torch +import json +import cv2 +import numpy as np +from PIL import Image +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file as load_sft + +from optimum.quanto import requantize + +from .model import Flux, FluxParams +from .controlnet import ControlNetFlux +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder +from .annotator.dwpose import DWposeDetector +from .annotator.mlsd import MLSDdetector +from .annotator.canny import CannyDetector +from .annotator.midas import MidasDetector +from .annotator.hed import HEDdetector +from .annotator.tile import TileDetector +from .annotator.zoe import ZoeDetector + + +def load_safetensors(path): + tensors = {} + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + +def get_lora_rank(checkpoint): + for k in checkpoint.keys(): + if k.endswith(".down.weight"): + return checkpoint[k].shape[0] + +def load_checkpoint(local_path, repo_id, name): + if local_path is not None: + if '.safetensors' in local_path: + print(f"Loading .safetensors checkpoint from {local_path}") + checkpoint = load_safetensors(local_path) + else: + print(f"Loading checkpoint from {local_path}") + checkpoint = torch.load(local_path, map_location='cpu') + elif repo_id is not None and name is not None: + print(f"Loading checkpoint {name} from repo id {repo_id}") + checkpoint = load_from_repo_id(repo_id, name) + else: + raise ValueError( + "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" + ) + return checkpoint + + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +def pad64(x): + return int(np.ceil(float(x) / 64.0) * 64 - x) + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y 
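+# HWC3 normalizes any uint8 image (grayscale, RGB or RGBA) to 3-channel RGB,
+# compositing RGBA over a white background. Illustrative usage (shapes only):
+#   gray = np.zeros((64, 64), dtype=np.uint8)
+#   assert HWC3(gray).shape == (64, 64, 3)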
+ +def safer_memory(x): + # Fix many MAC/AMD problems + return np.ascontiguousarray(x.copy()).copy() + +#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17 +#Added upscale_method, mode params +def resize_image_with_pad(input_image, resolution, skip_hwc3=False, mode='edge'): + if skip_hwc3: + img = input_image + else: + img = HWC3(input_image) + H_raw, W_raw, _ = img.shape + if resolution == 0: + return img, lambda x: x + k = float(resolution) / float(min(H_raw, W_raw)) + H_target = int(np.round(float(H_raw) * k)) + W_target = int(np.round(float(W_raw) * k)) + img = cv2.resize(img, (W_target, H_target), interpolation=cv2.INTER_AREA) + H_pad, W_pad = pad64(H_target), pad64(W_target) + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode) + + def remove_pad(x): + return safer_memory(x[:H_target, :W_target, ...]) + + return safer_memory(img_padded), remove_pad + +class Annotator: + def __init__(self, name: str, device: str): + if name == "canny": + processor = CannyDetector() + elif name == "openpose": + processor = DWposeDetector(device) + elif name == "depth": + processor = MidasDetector() + elif name == "hed": + processor = HEDdetector() + elif name == "hough": + processor = MLSDdetector() + elif name == "tile": + processor = TileDetector() + elif name == "zoe": + processor = ZoeDetector() + self.name = name + self.processor = processor + + def __call__(self, image: Image, width: int, height: int): + image = np.array(image) + detect_resolution = max(width, height) + image, remove_pad = resize_image_with_pad(image, detect_resolution) + + image = np.array(image) + if self.name == "canny": + result = self.processor(image, low_threshold=100, high_threshold=200) + elif self.name == "hough": + result = self.processor(image, thr_v=0.05, thr_d=5) + elif self.name == "depth": + result = self.processor(image) + result, _ = result + else: + result = self.processor(image) + + result = HWC3(remove_pad(result)) + result = cv2.resize(result, (width, height)) + return result + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + repo_id_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fp8": ModelSpec( + repo_id="XLabs-AI/flux-dev-fp8", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux-dev-fp8.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_FP8"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + 
ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + +def load_from_repo_id(repo_id, checkpoint_name): + ckpt_path = hf_hub_download(repo_id, checkpoint_name) + sd = load_sft(ckpt_path, device='cpu') + return sd + +def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + json_path = 
hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') + + + model = Flux(configs[name].params).to(torch.bfloat16) + + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device='cpu') + with open(json_path, "r") as f: + quantization_map = json.load(f) + print("Start a quantization process...") + requantize(model, sd, quantization_map, device=device) + print("Model is quantized!") + return model + +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = ControlNetFlux(configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device) + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: + ckpt_path = configs[name].ae_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) + + # Loading the autoencoder + print("Init AE") + with torch.device("meta" if ckpt_path is not None else device): + ae = AutoEncoder(configs[name].ae_params) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [-1, 1] + + Returns: + same as input but watermarked + """ + image = 0.5 * image + 0.5 + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] 
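+        # If the input was a plain (B, C, H, W) batch it has just been lifted to
+        # (1, B, C, H, W), so the rearrange below handles batched and stacked
+        # inputs through the same path.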
+ n = image.shape[0] + image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( + image.device + ) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + image = 2 * image - 1 + return image + + +# A fixed 48-bit message that was choosen at random +WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f1325497c3bb67e957c60d672b4527633dd7f547 --- /dev/null +++ b/src/flux/xflux_pipeline.py @@ -0,0 +1,364 @@ +from PIL import Image, ExifTags +import numpy as np +import torch +from torch import Tensor + +from einops import rearrange +import uuid +import os + +from src.flux.modules.layers import ( + SingleStreamBlockProcessor, + DoubleStreamBlockProcessor, + SingleStreamBlockLoraProcessor, + DoubleStreamBlockLoraProcessor, + IPDoubleStreamBlockProcessor, + ImageProjModel, +) +from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack +from src.flux.util import ( + load_ae, + load_clip, + load_flow_model, + load_t5, + load_controlnet, + load_flow_model_quintized, + Annotator, + get_lora_rank, + load_checkpoint +) + +from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor + +class XFluxPipeline: + def __init__(self, model_type, device, offload: bool = False): + self.device = torch.device(device) + self.offload = offload + self.model_type = model_type + + self.clip = load_clip(self.device) + self.t5 = load_t5(self.device, max_length=512) + self.ae = load_ae(model_type, device="cpu" if offload else self.device) + if "fp8" in model_type: + self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device) + else: + self.model = load_flow_model(model_type, device="cpu" if offload else self.device) + + self.image_encoder_path = "openai/clip-vit-large-patch14" + self.hf_lora_collection = "XLabs-AI/flux-lora-collection" + self.lora_types_to_names = { + "realism": "lora.safetensors", + } + self.controlnet_loaded = False + self.ip_loaded = False + + def set_ip(self, local_path: str = None, repo_id = None, name: str = None): + self.model.to(self.device) + + # unpack checkpoint + checkpoint = load_checkpoint(local_path, repo_id, name) + prefix = "double_blocks." 
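+        # The IP-Adapter checkpoint stores double-block attention weights under
+        # "double_blocks.<i>.processor.*" and the image projection weights under
+        # "ip_adapter_proj_model.*"; split the flat state dict into those two parts.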
+ blocks = {} + proj = {} + + for key, value in checkpoint.items(): + if key.startswith(prefix): + blocks[key[len(prefix):].replace('.processor.', '.')] = value + if key.startswith("ip_adapter_proj_model"): + proj[key[len("ip_adapter_proj_model."):]] = value + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) + self.clip_image_processor = CLIPImageProcessor() + + # setup image embedding projection model + self.improj = ImageProjModel(4096, 768, 4) + self.improj.load_state_dict(proj) + self.improj = self.improj.to(self.device, dtype=torch.bfloat16) + + ip_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + ip_state_dict = {} + for k in checkpoint.keys(): + if name in k: + ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k] + if ip_state_dict: + ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) + ip_attn_procs[name].load_state_dict(ip_state_dict) + ip_attn_procs[name].to(self.device, dtype=torch.bfloat16) + else: + ip_attn_procs[name] = self.model.attn_processors[name] + + self.model.set_attn_processor(ip_attn_procs) + self.ip_loaded = True + + def set_lora(self, local_path: str = None, repo_id: str = None, + name: str = None, lora_weight: int = 0.7): + checkpoint = load_checkpoint(local_path, repo_id, name) + self.update_model_with_lora(checkpoint, lora_weight) + + def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7): + checkpoint = load_checkpoint( + None, self.hf_lora_collection, self.lora_types_to_names[lora_type] + ) + self.update_model_with_lora(checkpoint, lora_weight) + + def update_model_with_lora(self, checkpoint, lora_weight): + rank = get_lora_rank(checkpoint) + lora_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + lora_state_dict = {} + for k in checkpoint.keys(): + if name in k: + lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight + + if len(lora_state_dict): + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank) + else: + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank) + lora_attn_procs[name].load_state_dict(lora_state_dict) + lora_attn_procs[name].to(self.device) + else: + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockProcessor() + else: + lora_attn_procs[name] = DoubleStreamBlockProcessor() + + self.model.set_attn_processor(lora_attn_procs) + + def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None): + self.model.to(self.device) + self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16) + + checkpoint = load_checkpoint(local_path, repo_id, name) + self.controlnet.load_state_dict(checkpoint, strict=False) + self.annotator = Annotator(control_type, self.device) + self.controlnet_loaded = True + self.control_type = control_type + + def get_image_proj( + self, + image_prompt: Tensor, + ): + # encode image-prompt embeds + image_prompt = self.clip_image_processor( + images=image_prompt, + return_tensors="pt" + ).pixel_values + image_prompt = image_prompt.to(self.image_encoder.device) + image_prompt_embeds = self.image_encoder( + image_prompt + ).image_embeds.to( + device=self.device, dtype=torch.bfloat16, + ) + # encode image + image_proj = self.improj(image_prompt_embeds) + return image_proj + + def __call__(self, + prompt: str, + image_prompt: Image = None, + 
controlnet_image: Image = None, + width: int = 512, + height: int = 512, + guidance: float = 4, + num_steps: int = 50, + seed: int = 123456789, + true_gs: float = 3, + control_weight: float = 0.9, + ip_scale: float = 1.0, + neg_ip_scale: float = 1.0, + neg_prompt: str = '', + neg_image_prompt: Image = None, + timestep_to_start_cfg: int = 0, + ): + width = 16 * (width // 16) + height = 16 * (height // 16) + image_proj = None + neg_image_proj = None + if not (image_prompt is None and neg_image_prompt is None) : + assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input' + + if image_prompt is None: + image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + if neg_image_prompt is None: + neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + + image_proj = self.get_image_proj(image_prompt) + neg_image_proj = self.get_image_proj(neg_image_prompt) + + if self.controlnet_loaded: + controlnet_image = self.annotator(controlnet_image, width, height) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute( + 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) + + return self.forward( + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + true_gs=true_gs, + control_weight=control_weight, + neg_prompt=neg_prompt, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + @torch.inference_mode() + def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg, control_type, control_weight, + lora_weight, local_path, lora_local_path, ip_local_path): + if controlnet_image is not None: + controlnet_image = Image.fromarray(controlnet_image) + if ((self.controlnet_loaded and control_type != self.control_type) + or not self.controlnet_loaded): + if local_path is not None: + self.set_controlnet(control_type, local_path=local_path) + else: + self.set_controlnet(control_type, local_path=None, + repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3", + name=f"flux-{control_type}-controlnet-v3.safetensors") + if lora_local_path is not None: + self.set_lora(local_path=lora_local_path, lora_weight=lora_weight) + if image_prompt is not None: + image_prompt = Image.fromarray(image_prompt) + if neg_image_prompt is not None: + neg_image_prompt = Image.fromarray(neg_image_prompt) + if not self.ip_loaded: + if ip_local_path is not None: + self.set_ip(local_path=ip_local_path) + else: + self.set_ip(repo_id="xlabs-ai/flux-ip-adapter", + name="flux-ip-adapter.safetensors") + seed = int(seed) + if seed == -1: + seed = torch.Generator(device="cpu").seed() + + img = self(prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg) + + filename = f"output/gradio/{uuid.uuid4()}.jpg" + os.makedirs(os.path.dirname(filename), exist_ok=True) + exif_data = Image.Exif() + exif_data[ExifTags.Base.Make] = "XLabs AI" + exif_data[ExifTags.Base.Model] = self.model_type + img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) + return img, filename + + def forward( + self, + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image = None, + timestep_to_start_cfg = 0, + true_gs = 3.5, + control_weight 
= 0.9, + neg_prompt="", + image_proj=None, + neg_image_proj=None, + ip_scale=1.0, + neg_ip_scale=1.0, + ): + x = get_noise( + 1, height, width, device=self.device, + dtype=torch.bfloat16, seed=seed + ) + timesteps = get_schedule( + num_steps, + (width // 8) * (height // 8) // (16 * 16), + shift=True, + ) + torch.manual_seed(seed) + with torch.no_grad(): + if self.offload: + self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) + inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt) + neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) + + if self.offload: + self.offload_model_to_cpu(self.t5, self.clip) + self.model = self.model.to(self.device) + if self.controlnet_loaded: + x = denoise_controlnet( + self.model, + **inp_cond, + controlnet=self.controlnet, + timesteps=timesteps, + guidance=guidance, + controlnet_cond=controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + controlnet_gs=control_weight, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + else: + x = denoise( + self.model, + **inp_cond, + timesteps=timesteps, + guidance=guidance, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + if self.offload: + self.offload_model_to_cpu(self.model) + self.ae.decoder.to(x.device) + x = unpack(x.float(), height, width) + x = self.ae.decode(x) + self.offload_model_to_cpu(self.ae.decoder) + + x1 = x.clamp(-1, 1) + x1 = rearrange(x1[-1], "c h w -> h w c") + output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) + return output_img + + def offload_model_to_cpu(self, *models): + if not self.offload: return + for model in models: + model.cpu() + torch.cuda.empty_cache() + + +class XFluxSampler(XFluxPipeline): + def __init__(self, clip, t5, ae, model, device): + self.clip = clip + self.t5 = t5 + self.ae = ae + self.model = model + self.model.eval() + self.device = device + self.controlnet_loaded = False + self.ip_loaded = False + self.offload = False diff --git a/train_configs/test_canny_controlnet.yaml b/train_configs/test_canny_controlnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1b4c7ee36ecf9c560522f0b4545a5d7f5b5861f4 --- /dev/null +++ b/train_configs/test_canny_controlnet.yaml @@ -0,0 +1,25 @@ +model_name: "flux-dev" +data_config: + train_batch_size: 4 + num_workers: 4 + img_size: 512 + img_dir: images/ +report_to: wandb +train_batch_size: 3 +output_dir: saves_canny/ +max_train_steps: 100000 +learning_rate: 2e-5 +lr_scheduler: constant +lr_warmup_steps: 10 +adam_beta1: 0.9 +adam_beta2: 0.999 +adam_weight_decay: 0.01 +adam_epsilon: 1e-8 +max_grad_norm: 1.0 +logging_dir: logs +mixed_precision: "bf16" +checkpointing_steps: 2500 +checkpoints_total_limit: 10 +tracker_project_name: canny_training +resume_from_checkpoint: latest +gradient_accumulation_steps: 2 diff --git a/train_configs/test_finetune.yaml b/train_configs/test_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..173bfeef507459190ecc8f528991b7f5c86aaab9 --- /dev/null +++ b/train_configs/test_finetune.yaml @@ -0,0 +1,25 @@ +model_name: "flux-dev" +data_config: + train_batch_size: 1 + 
num_workers: 4 + img_size: 512 + img_dir: images/ +report_to: wandb +train_batch_size: 1 +output_dir: saves/ +max_train_steps: 100000 +learning_rate: 1e-5 +lr_scheduler: constant +lr_warmup_steps: 10 +adam_beta1: 0.9 +adam_beta2: 0.999 +adam_weight_decay: 0.01 +adam_epsilon: 1e-8 +max_grad_norm: 1.0 +logging_dir: logs +mixed_precision: "bf16" +checkpointing_steps: 2500 +checkpoints_total_limit: 10 +tracker_project_name: finetune_test +resume_from_checkpoint: latest +gradient_accumulation_steps: 2 diff --git a/train_configs/test_lora.yaml b/train_configs/test_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1f5ac0cb98304e92ef5e72b5af9198155cc6100 --- /dev/null +++ b/train_configs/test_lora.yaml @@ -0,0 +1,37 @@ +model_name: "flux-dev" +data_config: + train_batch_size: 1 + num_workers: 4 + img_size: 512 + img_dir: images/ + random_ratio: true # support multi crop preprocessing +report_to: wandb +train_batch_size: 1 +output_dir: lora/ +max_train_steps: 100000 +learning_rate: 1e-5 +lr_scheduler: constant +lr_warmup_steps: 10 +adam_beta1: 0.9 +adam_beta2: 0.999 +adam_weight_decay: 0.01 +adam_epsilon: 1e-8 +max_grad_norm: 1.0 +logging_dir: logs +mixed_precision: "bf16" +checkpointing_steps: 2500 +checkpoints_total_limit: 10 +tracker_project_name: lora_test +resume_from_checkpoint: latest +gradient_accumulation_steps: 2 +rank: 16 +single_blocks: "1,2,3,4" +double_blocks: null +disable_sampling: false +sample_every: 250 # sample every this many steps +sample_width: 1024 +sample_height: 1024 +sample_steps: 20 +sample_prompts: + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" diff --git a/train_flux_deepspeed.py b/train_flux_deepspeed.py new file mode 100644 index 0000000000000000000000000000000000000000..d8853e0c169e28ee4b06a6060ac386e552d63124 --- /dev/null +++ b/train_flux_deepspeed.py @@ -0,0 +1,303 @@ +import argparse +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers +from omegaconf import OmegaConf +from copy import deepcopy +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr +from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from src.flux.util import (configs, load_ae, load_clip, + load_flow_model2, load_t5) +from image_datasets.dataset import loader +if is_wandb_available(): + import 
wandb +logger = get_logger(__name__, log_level="INFO") + +def get_models(name: str, device, offload: bool, is_schnell: bool): + t5 = load_t5(device, max_length=256 if is_schnell else 512) + clip = load_clip(device) + clip.requires_grad_(False) + model = load_flow_model2(name, device="cpu") + vae = load_ae(name, device="cpu" if offload else device) + return model, vae, t5, clip + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--config", + type=str, + default=None, + required=True, + help="path to config", + ) + args = parser.parse_args() + + + return args.config +def main(): + + args = OmegaConf.load(parse_args()) + is_schnell = args.model_name == "flux-schnell" + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell) + + vae.requires_grad_(False) + t5.requires_grad_(False) + clip.requires_grad_(False) + dit = dit.to(torch.float32) + dit.train() + optimizer_cls = torch.optim.AdamW + #you can train your own layers + for n, param in dit.named_parameters(): + if 'txt_attn' not in n: + param.requires_grad = False + + print(sum([p.numel() for p in dit.parameters() if p.requires_grad]) / 1000000, 'parameters') + optimizer = optimizer_cls( + [p for p in dit.parameters() if p.requires_grad], + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + train_dataloader = loader(**args.data_config) + # Scheduler and math around the number of training steps. 
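+    # e.g. 1000 dataloader batches with gradient_accumulation_steps=2 give 500
+    # optimizer updates per epoch; max_train_steps and num_train_epochs are
+    # reconciled against this count below.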
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + dit_state = torch.load(os.path.join(args.output_dir, path, 'dit.bin'), map_location='cpu') + dit_state2 = {} + for k in dit_state.keys(): + dit_state2[k[len('module.'):]] = dit_state[k] + dit.load_state_dict(dit_state2) + optimizer_state = torch.load(os.path.join(args.output_dir, path, 'optimizer.bin'), map_location='cpu')['base_optimizer_state'] + optimizer.load_state_dict(optimizer_state) + + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + dit, optimizer, _, lr_scheduler = accelerator.prepare( + dit, optimizer, deepcopy(train_dataloader), lr_scheduler + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if accelerator.is_main_process: + accelerator.init_trackers(args.tracker_project_name, {"test": None}) + + timesteps = list(torch.linspace(1, 0, 1000).numpy()) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(dit): + img, prompts = batch + with torch.no_grad(): + x_1 = vae.encode(img.to(accelerator.device).to(torch.float32)) + inp = prepare(t5=t5, clip=clip, img=x_1, prompt=prompts) + x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + + + bs = img.shape[0] + t = torch.sigmoid(torch.randn((bs,), device=accelerator.device)) + x_0 = torch.randn_like(x_1).to(accelerator.device) + x_t = (1 - t) * x_1 + t * x_0 + bsz = x_1.shape[0] + guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype) + + # Predict the noise residual and compute loss + model_pred = dit(img=x_t.to(weight_dtype), + img_ids=inp['img_ids'].to(weight_dtype), + txt=inp['txt'].to(weight_dtype), + txt_ids=inp['txt_ids'].to(weight_dtype), + y=inp['vec'].to(weight_dtype), + timesteps=t.to(weight_dtype), + guidance=guidance_vec.to(weight_dtype),) + + #loss = F.mse_loss(model_pred.float(), (x_1 - x_0).float(), reduction="mean") + loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(dit.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + if not os.path.exists(save_path): + os.mkdir(save_path) + torch.save(dit.state_dict(), os.path.join(save_path, 'dit.bin')) + torch.save(optimizer.state_dict(), 
os.path.join(save_path, 'optimizer.bin')) + #accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train_flux_deepspeed_controlnet.py b/train_flux_deepspeed_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa8cc690d7f000aa36fba8b3b3c90ed4d37e6ae3 --- /dev/null +++ b/train_flux_deepspeed_controlnet.py @@ -0,0 +1,316 @@ +import argparse +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers +from omegaconf import OmegaConf +from copy import deepcopy +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr +from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from src.flux.util import (configs, load_ae, load_clip, + load_flow_model2, load_controlnet, load_t5) +from image_datasets.canny_dataset import loader +if is_wandb_available(): + import wandb +logger = get_logger(__name__, log_level="INFO") + +def get_models(name: str, device, offload: bool, is_schnell: bool): + t5 = load_t5(device, max_length=256 if is_schnell else 512) + clip = load_clip(device) + model = load_flow_model2(name, device="cpu") + vae = load_ae(name, device="cpu" if offload else device) + return model, vae, t5, clip + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--config", + type=str, + default=None, + required=True, + help="path to config", + ) + args = parser.parse_args() + + + return args.config +def main(): + + args = OmegaConf.load(parse_args()) + is_schnell = args.model_name == "flux-schnell" + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + print("DEVICE", accelerator.device) + dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell) + + vae.requires_grad_(False) + t5.requires_grad_(False) + clip.requires_grad_(False) + dit.requires_grad_(False) + dit.to(accelerator.device) + + controlnet = load_controlnet(name=args.model_name, device=accelerator.device, transformer=dit) + controlnet = controlnet.to(torch.float32) + controlnet.train() + + optimizer_cls = torch.optim.AdamW + + print(sum([p.numel() for p in controlnet.parameters() if p.requires_grad]) / 1000000, 'parameters') + optimizer = optimizer_cls( + [p for p in controlnet.parameters() if p.requires_grad], + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + train_dataloader = loader(**args.data_config) + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + controlnet, optimizer, _, lr_scheduler = accelerator.prepare( + controlnet, optimizer, deepcopy(train_dataloader), lr_scheduler + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if accelerator.is_main_process: + accelerator.init_trackers(args.tracker_project_name, {"test": None}) + + timesteps = list(torch.linspace(1, 0, 1000).numpy()) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + img, control_image, prompts = batch + control_image = control_image.to(accelerator.device) + with torch.no_grad(): + x_1 = vae.encode(img.to(accelerator.device).to(torch.float32)) + inp = prepare(t5=t5, clip=clip, img=x_1, prompt=prompts) + + x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + + bs = img.shape[0] + t = torch.sigmoid(torch.randn((bs,), device=accelerator.device)) + + x_0 = torch.randn_like(x_1).to(accelerator.device) + print(t.shape, x_1.shape, x_0.shape) + x_t = (1 - t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2])) * x_1 + t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2]) * x_0 + bsz = x_1.shape[0] + guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype) + + block_res_samples = controlnet( + img=x_t.to(weight_dtype), + img_ids=inp['img_ids'].to(weight_dtype), + controlnet_cond=control_image.to(weight_dtype), + txt=inp['txt'].to(weight_dtype), + txt_ids=inp['txt_ids'].to(weight_dtype), + y=inp['vec'].to(weight_dtype), + timesteps=t.to(weight_dtype), + guidance=guidance_vec.to(weight_dtype), + ) + # Predict the noise residual and compute loss + model_pred = dit( + img=x_t.to(weight_dtype), + img_ids=inp['img_ids'].to(weight_dtype), + txt=inp['txt'].to(weight_dtype), + txt_ids=inp['txt_ids'].to(weight_dtype), + block_controlnet_hidden_states=[ + sample.to(dtype=weight_dtype) for sample in block_res_samples + ], + y=inp['vec'].to(weight_dtype), + timesteps=t.to(weight_dtype), + guidance=guidance_vec.to(weight_dtype), + ) + + loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). 
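+                # `loss.repeat(args.train_batch_size)` duplicates the scalar loss so
+                # that gather returns one entry per sample before averaging across
+                # processes.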
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(controlnet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + #if not os.path.exists(save_path): + # os.mkdir(save_path) + + accelerator.save_state(save_path) + unwrapped_model = accelerator.unwrap_model(controlnet) + + torch.save(unwrapped_model.state_dict(), os.path.join(save_path, 'controlnet.bin')) + logger.info(f"Saved state to {save_path}") + + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/train_flux_lora_deepspeed.py b/train_flux_lora_deepspeed.py new file mode 100644 index 0000000000000000000000000000000000000000..f6a09a4e91208561fa2f5944b6ae485b9b86833c --- /dev/null +++ b/train_flux_lora_deepspeed.py @@ -0,0 +1,355 @@ +import argparse +import logging +import math +import os +import re +import random +import shutil +from contextlib import nullcontext +from pathlib import Path +from safetensors.torch import save_file + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers +from omegaconf import OmegaConf +from copy import deepcopy +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline +from diffusers.optimization import get_scheduler +from 
diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr +from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from src.flux.util import (configs, load_ae, load_clip, + load_flow_model2, load_t5) +from src.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor +from src.flux.xflux_pipeline import XFluxSampler + +from image_datasets.dataset import loader +if is_wandb_available(): + import wandb +logger = get_logger(__name__, log_level="INFO") + +def get_models(name: str, device, offload: bool, is_schnell: bool): + t5 = load_t5(device, max_length=256 if is_schnell else 512) + clip = load_clip(device) + clip.requires_grad_(False) + model = load_flow_model2(name, device="cpu") + vae = load_ae(name, device="cpu" if offload else device) + return model, vae, t5, clip + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--config", + type=str, + default=None, + required=True, + help="path to config", + ) + args = parser.parse_args() + + + return args.config + + +def main(): + args = OmegaConf.load(parse_args()) + is_schnell = args.model_name == "flux-schnell" + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. 
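+    # basicConfig applies on every rank; library verbosity is then relaxed only on the
+    # local main process so the other ranks do not duplicate log output.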
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+    dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell)
+    lora_attn_procs = {}
+
+    if args.double_blocks is None:
+        double_blocks_idx = list(range(19))
+    else:
+        double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")]
+
+    if args.single_blocks is None:
+        single_blocks_idx = list(range(38))
+    else:
+        single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")]
+
+    for name, attn_processor in dit.attn_processors.items():
+        match = re.search(r'\.(\d+)\.', name)
+        if match:
+            layer_index = int(match.group(1))
+
+            if name.startswith("double_blocks") and layer_index in double_blocks_idx:
+                print("setting LoRA Processor for", name)
+                lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
+                    dim=3072, rank=args.rank
+                )
+            elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
+                print("setting LoRA Processor for", name)
+                lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
+                    dim=3072, rank=args.rank
+                )
+            else:
+                lora_attn_procs[name] = attn_processor
+
+    dit.set_attn_processor(lora_attn_procs)
+
+    vae.requires_grad_(False)
+    t5.requires_grad_(False)
+    clip.requires_grad_(False)
+    dit = dit.to(torch.float32)
+    dit.train()
+    optimizer_cls = torch.optim.AdamW
+    for n, param in dit.named_parameters():
+        if '_lora' not in n:
+            param.requires_grad = False
+        else:
+            print(n)
+    print(f"{sum(p.numel() for p in dit.parameters() if p.requires_grad) / 1e6:.2f}M trainable parameters")
+    optimizer = optimizer_cls(
+        [p for p in dit.parameters() if p.requires_grad],
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
+
+    train_dataloader = loader(**args.data_config)
+    # Scheduler and math around the number of training steps.
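+    # When `max_train_steps` is absent from the config it is derived below from
+    # `num_train_epochs` and the number of optimizer updates per epoch (dataloader
+    # length divided by `gradient_accumulation_steps`); it is recomputed once more
+    # after `accelerator.prepare`, since preparation can change the dataloader length.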
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + global_step = 0 + first_epoch = 0 + + dit, optimizer, _, lr_scheduler = accelerator.prepare( + dit, optimizer, deepcopy(train_dataloader), lr_scheduler + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if accelerator.is_main_process: + accelerator.init_trackers(args.tracker_project_name, {"test": None}) + + timesteps = get_schedule( + 999, + (1024 // 8) * (1024 // 8) // 4, + shift=True, + ) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(dit): + img, prompts = batch + with torch.no_grad(): + x_1 = vae.encode(img.to(accelerator.device).to(torch.float32)) + inp = prepare(t5=t5, clip=clip, img=x_1, prompt=prompts) + x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + + bs = img.shape[0] + t = torch.tensor([timesteps[random.randint(0, 999)]]).to(accelerator.device) + x_0 = torch.randn_like(x_1).to(accelerator.device) + x_t = (1 - t) * x_1 + t * x_0 + bsz = x_1.shape[0] + guidance_vec = torch.full((x_t.shape[0],), 1, device=x_t.device, dtype=x_t.dtype) + + # Predict the noise residual and compute loss + model_pred = dit(img=x_t.to(weight_dtype), + img_ids=inp['img_ids'].to(weight_dtype), + txt=inp['txt'].to(weight_dtype), + txt_ids=inp['txt_ids'].to(weight_dtype), + y=inp['vec'].to(weight_dtype), + timesteps=t.to(weight_dtype), + guidance=guidance_vec.to(weight_dtype),) + + loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(dit.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if not args.disable_sampling and global_step % args.sample_every == 0: + if accelerator.is_main_process: + print(f"Sampling images for step {global_step}...") + sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device) + images = [] + for i, prompt in enumerate(args.sample_prompts): + result = sampler(prompt=prompt, + width=args.sample_width, + height=args.sample_height, + num_steps=args.sample_steps + ) + images.append(wandb.Image(result)) + print(f"Result for prompt #{i} is generated") + # result.save(f"{global_step}_prompt_{i}_res.png") + wandb.log({f"Results, step {global_step}": images}) + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = 
len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + + accelerator.save_state(save_path) + unwrapped_model_state = accelerator.unwrap_model(dit).state_dict() + + # save checkpoint in safetensors format + lora_state_dict = {k:unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k} + save_file( + lora_state_dict, + os.path.join(save_path, "lora.safetensors") + ) + + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + main()
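
For reference, a minimal sketch of how the `lora.safetensors` written by the checkpointing loop above might be loaded back for inference. This is an illustration under assumptions, not part of the diff: it assumes every double/single block was adapted with rank 16, reuses the repository helpers already imported by the training script, and the checkpoint path is hypothetical.

    import torch
    from safetensors.torch import load_file

    from src.flux.util import load_flow_model2
    from src.flux.modules.layers import (DoubleStreamBlockLoraProcessor,
                                         SingleStreamBlockLoraProcessor)

    device = "cuda"
    dit = load_flow_model2("flux-dev", device=device)

    # Re-create the same LoRA processors that were attached during training
    # (assumption: all double/single blocks were adapted with rank=16).
    lora_attn_procs = {}
    for name, processor in dit.attn_processors.items():
        if name.startswith("double_blocks"):
            lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=16)
        elif name.startswith("single_blocks"):
            lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=16)
        else:
            lora_attn_procs[name] = processor
    dit.set_attn_processor(lora_attn_procs)

    # The checkpoint only holds the '_lora' tensors, so load them non-strictly
    # on top of the base weights (path below is illustrative).
    lora_state = load_file("output_dir/checkpoint-1000/lora.safetensors")
    dit.load_state_dict(lora_state, strict=False)
    dit = dit.to(device).eval()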