diff --git a/README.md b/README.md index 4476e1e91ae2afa82ed1067d9bbbb71138da56f1..2059deffee178ccecb41278c420b085d98e950f7 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ --- title: Control Ability Arena -emoji: 🔥 -colorFrom: yellow -colorTo: gray +emoji: 🖼 +colorFrom: purple +colorTo: red sdk: gradio -sdk_version: 5.9.1 +sdk_version: 5.0.1 app_file: app.py pinned: false --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b2338cb48086d3d0c3a430033fe831c5f79eb876 --- /dev/null +++ b/app.py @@ -0,0 +1,90 @@ +import gradio as gr +import os +from serve.gradio_web import * +from serve.gradio_web_bbox import build_side_by_side_bbox_ui_anony +from serve.leaderboard import build_leaderboard_tab, build_leaderboard_video_tab, build_leaderboard_contributor +from model.model_manager import ModelManager +from pathlib import Path +from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR + + +def make_default_md(): + link_color = "#1976D2" # This color should be clear in both light and dark mode + leaderboard_md = f""" + # 🏅 Control-Ability-Arena: ... + ### [Paper]... | [Twitter]... + - ⚡ For vision tasks, K-wise comparisons can provide much richer info but only take similar time as pairwise comparisons. + - 🎯 Well designed matchmaking algorithm can further save human efforts than random match pairing in normal Arena. + - 📈 Probabilistic modeling can obtain a faster and more stable convergence than Elo scoring system. + """ + + return leaderboard_md + + +def build_combine_demo(models): + with gr.Blocks( + title="Play with Open Vision Models", + theme=gr.themes.Default(), + css=block_css, + ) as demo: + + with gr.Blocks(): + md = make_default_md() + md_default = gr.Markdown(md, elem_id="default_leaderboard_markdown") + + + + with gr.Tabs() as tabs_combine: + # with gr.Tab("Image Generation", id=0): + # with gr.Tabs() as tabs_ig: + # # with gr.Tab("Generation Leaderboard", id=0): + # # build_leaderboard_tab() + # with gr.Tab("Generation Arena (battle)", id=1): + # build_side_by_side_ui_anony(models) + + with gr.Tab("BBox-to-Image Generation", id=0): + with gr.Tabs() as tabs_ig: + # with gr.Tab("Generation Leaderboard", id=0): + # build_leaderboard_tab() + with gr.Tab("Generation Arena (battle)", id=1): + build_side_by_side_bbox_ui_anony(models) + + # with gr.Tab("Contributor", id=2): + # build_leaderboard_contributor() + + return demo + + +def load_elo_results(elo_results_dir): + from collections import defaultdict + elo_results_file = defaultdict(lambda: None) + leaderboard_table_file = defaultdict(lambda: None) + + if elo_results_dir is not None: + elo_results_dir = Path(elo_results_dir) + elo_results_file = {} + leaderboard_table_file = {} + for file in elo_results_dir.glob('elo_results_*.pkl'): + if 't2i_generation' in file.name: + elo_results_file['t2i_generation'] = file + # else: + # raise ValueError(f"Unknown file name: {file.name}") + for file in elo_results_dir.glob('*_leaderboard.csv'): + if 't2i_generation' in file.name: + leaderboard_table_file['t2i_generation'] = file + # else: + # raise ValueError(f"Unknown file name: {file.name}") + + return elo_results_file, leaderboard_table_file + +if __name__ == "__main__": + server_port = int(SERVER_PORT) + root_path = ROOT_PATH + elo_results_dir = ELO_RESULTS_DIR + models = ModelManager() + + # elo_results_file, leaderboard_table_file = load_elo_results(elo_results_dir) + demo = build_combine_demo(models) + demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True) + + # demo.launch(server_name="0.0.0.0", server_port=7860, root_path=ROOT_PATH) \ No newline at end of file diff --git a/ksort-logs/vote_log/gr_web_image_editing.log b/ksort-logs/vote_log/gr_web_image_editing.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ksort-logs/vote_log/gr_web_image_editing_multi.log b/ksort-logs/vote_log/gr_web_image_editing_multi.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ksort-logs/vote_log/gr_web_image_generation.log b/ksort-logs/vote_log/gr_web_image_generation.log new file mode 100644 index 0000000000000000000000000000000000000000..cc7f780c006ae4cfc2fc523191bb56e633f95aab --- /dev/null +++ b/ksort-logs/vote_log/gr_web_image_generation.log @@ -0,0 +1,811 @@ +2024-12-24 12:54:21 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. +2024-12-24 12:54:21 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) +2024-12-24 12:54:24 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(, >), received 11. +2024-12-24 12:54:24 | ERROR | stderr | warnings.warn( +2024-12-24 12:54:24 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(, >), received 11. +2024-12-24 12:54:24 | ERROR | stderr | warnings.warn( +2024-12-24 12:54:24 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860 +2024-12-24 12:54:52 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:54:52 | INFO | stdout | len(layers) 1 +2024-12-24 12:54:52 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:54:55 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:54:55 | INFO | stdout | len(layers) 1 +2024-12-24 12:54:55 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:54:56 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:54:56 | INFO | stdout | len(layers) 1 +2024-12-24 12:54:56 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:54:58 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:54:58 | INFO | stdout | len(layers) 1 +2024-12-24 12:54:58 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:00 | INFO | stdout | +2024-12-24 12:55:00 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3. +2024-12-24 12:55:00 | INFO | stdout | +2024-12-24 12:55:00 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: +2024-12-24 12:55:00 | INFO | stdout | +2024-12-24 12:55:00 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2024-12-24 12:55:00 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +2024-12-24 12:55:00 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio +2024-12-24 12:55:04 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:04 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:04 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:06 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:06 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:06 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:09 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:09 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:09 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:14 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:14 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:14 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:21 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:21 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:21 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:21 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:21 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:21 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:25 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:25 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:25 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:26 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:26 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:26 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:27 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:27 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:27 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:29 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:29 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:29 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:31 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:31 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:31 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:35 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:35 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:35 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:41 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 12:55:41 | INFO | stdout | len(layers) 1 +2024-12-24 12:55:41 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 12:55:47 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 12:55:47 | ERROR | stderr | Output components: +2024-12-24 12:55:47 | ERROR | stderr | [textbox, button, button] +2024-12-24 12:55:47 | ERROR | stderr | Output values returned: +2024-12-24 12:55:47 | ERROR | stderr | [{'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 12:55:47 | ERROR | stderr | warnings.warn( +2024-12-24 12:55:47 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 12:55:47 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 12:55:47 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 12:55:47 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 12:55:47 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 12:55:47 | ERROR | stderr | result = await self.call_function( +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 12:55:47 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 12:55:47 | ERROR | stderr | return await anext(iterator) +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 12:55:47 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 12:55:47 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 12:55:47 | ERROR | stderr | return await future +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 12:55:47 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 12:55:47 | ERROR | stderr | return next(iterator) +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper +2024-12-24 12:55:47 | ERROR | stderr | response = next(iterator) +2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 793, in generate_igm_annoy +2024-12-24 12:55:47 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3) +2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony +2024-12-24 12:55:47 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run) +2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker +2024-12-24 12:55:47 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp() +2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp +2024-12-24 12:55:47 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) +2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client +2024-12-24 12:55:47 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password) +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect +2024-12-24 12:55:47 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port)) +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses +2024-12-24 12:55:47 | ERROR | stderr | addrinfos = socket.getaddrinfo( +2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo +2024-12-24 12:55:47 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags): +2024-12-24 12:55:47 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype +2024-12-24 12:55:47 | INFO | stdout | Rank +2024-12-24 13:17:15 | INFO | stdout | Keyboard interruption in main thread... closing server. +2024-12-24 13:17:15 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread +2024-12-24 13:17:15 | ERROR | stderr | time.sleep(0.1) +2024-12-24 13:17:15 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:17:15 | ERROR | stderr | +2024-12-24 13:17:15 | ERROR | stderr | During handling of the above exception, another exception occurred: +2024-12-24 13:17:15 | ERROR | stderr | +2024-12-24 13:17:15 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:17:15 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in +2024-12-24 13:17:15 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True) +2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch +2024-12-24 13:17:15 | ERROR | stderr | self.block_thread() +2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread +2024-12-24 13:17:15 | ERROR | stderr | self.server.close() +2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close +2024-12-24 13:17:15 | ERROR | stderr | self.thread.join(timeout=5) +2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join +2024-12-24 13:17:15 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0)) +2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock +2024-12-24 13:17:15 | ERROR | stderr | if lock.acquire(block, timeout): +2024-12-24 13:17:15 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:17:23 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. +2024-12-24 13:17:23 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) +2024-12-24 13:17:25 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 10 arguments for function functools.partial(, >), received 11. +2024-12-24 13:17:25 | ERROR | stderr | warnings.warn( +2024-12-24 13:17:25 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1011: UserWarning: Expected maximum 10 arguments for function functools.partial(, >), received 11. +2024-12-24 13:17:25 | ERROR | stderr | warnings.warn( +2024-12-24 13:17:25 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860 +2024-12-24 13:17:40 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:17:40 | INFO | stdout | len(layers) 1 +2024-12-24 13:17:40 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:17:43 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:17:43 | INFO | stdout | len(layers) 1 +2024-12-24 13:17:43 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:17:47 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:17:47 | ERROR | stderr | Output components: +2024-12-24 13:17:47 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:17:47 | ERROR | stderr | Output values returned: +2024-12-24 13:17:47 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:17:47 | ERROR | stderr | warnings.warn( +2024-12-24 13:17:47 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 13:17:47 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 13:17:47 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 13:17:47 | ERROR | stderr | result = await self.call_function( +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 13:17:47 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 13:17:47 | ERROR | stderr | return await anext(iterator) +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 13:17:47 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 13:17:47 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 13:17:47 | ERROR | stderr | return await future +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 13:17:47 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 13:17:47 | ERROR | stderr | return next(iterator) +2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 847, in gen_wrapper +2024-12-24 13:17:47 | ERROR | stderr | iterator = f(*args, **kwargs) +2024-12-24 13:17:47 | ERROR | stderr | TypeError: generate_igm_annoy() takes 11 positional arguments but 12 were given +2024-12-24 13:17:48 | INFO | stdout | Rank +2024-12-24 13:17:56 | INFO | stdout | +2024-12-24 13:17:56 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3. +2024-12-24 13:17:56 | INFO | stdout | +2024-12-24 13:17:56 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: +2024-12-24 13:17:56 | INFO | stdout | +2024-12-24 13:17:56 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2024-12-24 13:17:56 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +2024-12-24 13:17:56 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio +2024-12-24 13:18:00 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:18:00 | ERROR | stderr | Output components: +2024-12-24 13:18:00 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:18:00 | ERROR | stderr | Output values returned: +2024-12-24 13:18:00 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:18:00 | ERROR | stderr | warnings.warn( +2024-12-24 13:18:00 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 13:18:00 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 13:18:00 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 13:18:00 | ERROR | stderr | result = await self.call_function( +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 13:18:00 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 13:18:00 | ERROR | stderr | return await anext(iterator) +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 13:18:00 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 13:18:00 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 13:18:00 | ERROR | stderr | return await future +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 13:18:00 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 13:18:00 | ERROR | stderr | return next(iterator) +2024-12-24 13:18:00 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 847, in gen_wrapper +2024-12-24 13:18:00 | ERROR | stderr | iterator = f(*args, **kwargs) +2024-12-24 13:18:00 | ERROR | stderr | TypeError: generate_igm_annoy() takes 11 positional arguments but 12 were given +2024-12-24 13:18:00 | INFO | stdout | Rank +2024-12-24 13:18:01 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:18:01 | ERROR | stderr | Output components: +2024-12-24 13:18:01 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:18:01 | ERROR | stderr | Output values returned: +2024-12-24 13:18:01 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:18:01 | ERROR | stderr | warnings.warn( +2024-12-24 13:18:01 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 13:18:01 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 13:18:01 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 13:18:01 | ERROR | stderr | result = await self.call_function( +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 13:18:01 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 13:18:01 | ERROR | stderr | return await anext(iterator) +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 13:18:01 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 13:18:01 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 13:18:01 | ERROR | stderr | return await future +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 13:18:01 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 13:18:01 | ERROR | stderr | return next(iterator) +2024-12-24 13:18:01 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 847, in gen_wrapper +2024-12-24 13:18:01 | ERROR | stderr | iterator = f(*args, **kwargs) +2024-12-24 13:18:01 | ERROR | stderr | TypeError: generate_igm_annoy() takes 11 positional arguments but 12 were given +2024-12-24 13:18:01 | INFO | stdout | Rank +2024-12-24 13:32:32 | INFO | stdout | Keyboard interruption in main thread... closing server. +2024-12-24 13:32:32 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:32:32 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread +2024-12-24 13:32:32 | ERROR | stderr | time.sleep(0.1) +2024-12-24 13:32:32 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:32:32 | ERROR | stderr | +2024-12-24 13:32:32 | ERROR | stderr | During handling of the above exception, another exception occurred: +2024-12-24 13:32:32 | ERROR | stderr | +2024-12-24 13:32:32 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:32:32 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in +2024-12-24 13:32:32 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True) +2024-12-24 13:32:32 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch +2024-12-24 13:32:32 | ERROR | stderr | self.block_thread() +2024-12-24 13:32:32 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread +2024-12-24 13:32:32 | ERROR | stderr | self.server.close() +2024-12-24 13:32:32 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close +2024-12-24 13:32:32 | ERROR | stderr | self.thread.join(timeout=5) +2024-12-24 13:32:32 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join +2024-12-24 13:32:32 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0)) +2024-12-24 13:32:32 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock +2024-12-24 13:32:32 | ERROR | stderr | if lock.acquire(block, timeout): +2024-12-24 13:32:32 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:32:52 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. +2024-12-24 13:32:52 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) +2024-12-24 13:32:54 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:32:54 | ERROR | stderr | warnings.warn( +2024-12-24 13:32:54 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:32:54 | ERROR | stderr | warnings.warn( +2024-12-24 13:32:54 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860 +2024-12-24 13:33:05 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:33:05 | ERROR | stderr | Output components: +2024-12-24 13:33:05 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:33:05 | ERROR | stderr | Output values returned: +2024-12-24 13:33:05 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:33:05 | ERROR | stderr | warnings.warn( +2024-12-24 13:33:05 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:33:05 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:33:06 | INFO | stdout | Rank +2024-12-24 13:33:17 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:33:17 | INFO | stdout | len(layers) 1 +2024-12-24 13:33:17 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:33:18 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:33:18 | INFO | stdout | len(layers) 1 +2024-12-24 13:33:18 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:33:19 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:33:19 | INFO | stdout | len(layers) 1 +2024-12-24 13:33:19 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:33:20 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:33:20 | ERROR | stderr | Output components: +2024-12-24 13:33:20 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:33:20 | ERROR | stderr | Output values returned: +2024-12-24 13:33:20 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:33:20 | ERROR | stderr | warnings.warn( +2024-12-24 13:33:20 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:33:20 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:33:20 | INFO | stdout | Rank +2024-12-24 13:33:25 | INFO | stdout | +2024-12-24 13:33:25 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3. +2024-12-24 13:33:25 | INFO | stdout | +2024-12-24 13:33:25 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: +2024-12-24 13:33:25 | INFO | stdout | +2024-12-24 13:33:25 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2024-12-24 13:33:25 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +2024-12-24 13:33:25 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio +2024-12-24 13:33:29 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:33:29 | ERROR | stderr | Output components: +2024-12-24 13:33:29 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:33:29 | ERROR | stderr | Output values returned: +2024-12-24 13:33:29 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:33:29 | ERROR | stderr | warnings.warn( +2024-12-24 13:33:30 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:33:30 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:33:30 | INFO | stdout | Rank +2024-12-24 13:33:31 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:33:31 | ERROR | stderr | Output components: +2024-12-24 13:33:31 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:33:31 | ERROR | stderr | Output values returned: +2024-12-24 13:33:31 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:33:31 | ERROR | stderr | warnings.warn( +2024-12-24 13:33:32 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:33:32 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:33:32 | INFO | stdout | Rank +2024-12-24 13:33:33 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:33:33 | ERROR | stderr | Output components: +2024-12-24 13:33:33 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:33:33 | ERROR | stderr | Output values returned: +2024-12-24 13:33:33 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:33:33 | ERROR | stderr | warnings.warn( +2024-12-24 13:33:33 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:33:33 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:33:34 | INFO | stdout | Rank +2024-12-24 13:33:34 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:33:34 | ERROR | stderr | Output components: +2024-12-24 13:33:34 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:33:34 | ERROR | stderr | Output values returned: +2024-12-24 13:33:34 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:33:34 | ERROR | stderr | warnings.warn( +2024-12-24 13:33:34 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:33:34 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:33:34 | INFO | stdout | Rank +2024-12-24 13:33:35 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:33:35 | ERROR | stderr | Output components: +2024-12-24 13:33:35 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:33:35 | ERROR | stderr | Output values returned: +2024-12-24 13:33:35 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:33:35 | ERROR | stderr | warnings.warn( +2024-12-24 13:33:35 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:33:35 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:33:35 | INFO | stdout | Rank +2024-12-24 13:33:52 | INFO | stdout | Keyboard interruption in main thread... closing server. +2024-12-24 13:33:52 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:33:52 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread +2024-12-24 13:33:52 | ERROR | stderr | time.sleep(0.1) +2024-12-24 13:33:52 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:33:52 | ERROR | stderr | +2024-12-24 13:33:52 | ERROR | stderr | During handling of the above exception, another exception occurred: +2024-12-24 13:33:52 | ERROR | stderr | +2024-12-24 13:33:52 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:33:52 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in +2024-12-24 13:33:52 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True) +2024-12-24 13:33:52 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch +2024-12-24 13:33:52 | ERROR | stderr | self.block_thread() +2024-12-24 13:33:52 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread +2024-12-24 13:33:52 | ERROR | stderr | self.server.close() +2024-12-24 13:33:52 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close +2024-12-24 13:33:52 | ERROR | stderr | self.thread.join(timeout=5) +2024-12-24 13:33:52 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join +2024-12-24 13:33:52 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0)) +2024-12-24 13:33:52 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock +2024-12-24 13:33:52 | ERROR | stderr | if lock.acquire(block, timeout): +2024-12-24 13:33:52 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:33:53 | ERROR | stderr | Exception ignored in: +2024-12-24 13:33:53 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:33:53 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1567, in _shutdown +2024-12-24 13:33:53 | ERROR | stderr | lock.acquire() +2024-12-24 13:33:53 | ERROR | stderr | KeyboardInterrupt: +2024-12-24 13:34:05 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. +2024-12-24 13:34:05 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) +2024-12-24 13:34:07 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:34:07 | ERROR | stderr | warnings.warn( +2024-12-24 13:34:07 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:34:07 | ERROR | stderr | warnings.warn( +2024-12-24 13:34:07 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860 +2024-12-24 13:34:18 | INFO | stdout | +2024-12-24 13:34:18 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3. +2024-12-24 13:34:18 | INFO | stdout | +2024-12-24 13:34:18 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: +2024-12-24 13:34:18 | INFO | stdout | +2024-12-24 13:34:18 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2024-12-24 13:34:18 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +2024-12-24 13:34:18 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio +2024-12-24 13:34:25 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:34:25 | INFO | stdout | len(layers) 1 +2024-12-24 13:34:25 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:34:26 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:34:26 | INFO | stdout | len(layers) 1 +2024-12-24 13:34:26 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:34:28 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:34:28 | ERROR | stderr | Output components: +2024-12-24 13:34:28 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:34:28 | ERROR | stderr | Output values returned: +2024-12-24 13:34:28 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:34:28 | ERROR | stderr | warnings.warn( +2024-12-24 13:34:29 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:34:29 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:34:29 | INFO | stdout | Rank +2024-12-24 13:43:01 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:43:01 | ERROR | stderr | Output components: +2024-12-24 13:43:01 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:43:01 | ERROR | stderr | Output values returned: +2024-12-24 13:43:01 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:43:01 | ERROR | stderr | warnings.warn( +2024-12-24 13:43:02 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:43:02 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:43:02 | INFO | stdout | Rank +2024-12-24 13:43:10 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:43:10 | INFO | stdout | len(layers) 1 +2024-12-24 13:43:10 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:43:12 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:43:12 | INFO | stdout | len(layers) 1 +2024-12-24 13:43:12 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:43:14 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:43:14 | INFO | stdout | len(layers) 1 +2024-12-24 13:43:14 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:43:17 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:43:17 | ERROR | stderr | Output components: +2024-12-24 13:43:17 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:43:17 | ERROR | stderr | Output values returned: +2024-12-24 13:43:17 | ERROR | stderr | [{'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:43:17 | ERROR | stderr | warnings.warn( +2024-12-24 13:43:17 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:43:17 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:43:17 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 13:43:17 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 13:43:17 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 13:43:17 | ERROR | stderr | result = await self.call_function( +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 13:43:17 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 13:43:17 | ERROR | stderr | return await anext(iterator) +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 13:43:17 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 13:43:17 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 13:43:17 | ERROR | stderr | return await future +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 13:43:17 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 13:43:17 | ERROR | stderr | return next(iterator) +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper +2024-12-24 13:43:17 | ERROR | stderr | response = next(iterator) +2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy +2024-12-24 13:43:17 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3) +2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony +2024-12-24 13:43:17 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run) +2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker +2024-12-24 13:43:17 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp() +2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp +2024-12-24 13:43:17 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) +2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client +2024-12-24 13:43:17 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password) +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect +2024-12-24 13:43:17 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port)) +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses +2024-12-24 13:43:17 | ERROR | stderr | addrinfos = socket.getaddrinfo( +2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo +2024-12-24 13:43:17 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags): +2024-12-24 13:43:17 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype +2024-12-24 13:43:17 | INFO | stdout | Rank +2024-12-24 13:44:00 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:44:00 | INFO | stdout | len(layers) 1 +2024-12-24 13:44:00 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:44:01 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:44:01 | INFO | stdout | len(layers) 1 +2024-12-24 13:44:01 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:44:08 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values. +2024-12-24 13:44:08 | ERROR | stderr | Output components: +2024-12-24 13:44:08 | ERROR | stderr | [textbox, button, button] +2024-12-24 13:44:08 | ERROR | stderr | Output values returned: +2024-12-24 13:44:08 | ERROR | stderr | [{'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}] +2024-12-24 13:44:08 | ERROR | stderr | warnings.warn( +2024-12-24 13:44:08 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:44:08 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:44:08 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 13:44:08 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 13:44:08 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 13:44:08 | ERROR | stderr | result = await self.call_function( +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 13:44:08 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 13:44:08 | ERROR | stderr | return await anext(iterator) +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 13:44:08 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 13:44:08 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 13:44:08 | ERROR | stderr | return await future +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 13:44:08 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 13:44:08 | ERROR | stderr | return next(iterator) +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper +2024-12-24 13:44:08 | ERROR | stderr | response = next(iterator) +2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy +2024-12-24 13:44:08 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3) +2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony +2024-12-24 13:44:08 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run) +2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker +2024-12-24 13:44:08 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp() +2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp +2024-12-24 13:44:08 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) +2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client +2024-12-24 13:44:08 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password) +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect +2024-12-24 13:44:08 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port)) +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses +2024-12-24 13:44:08 | ERROR | stderr | addrinfos = socket.getaddrinfo( +2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo +2024-12-24 13:44:08 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags): +2024-12-24 13:44:08 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype +2024-12-24 13:44:09 | INFO | stdout | Rank +2024-12-24 13:45:55 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:45:55 | INFO | stdout | len(layers) 1 +2024-12-24 13:45:55 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:45:57 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:45:57 | INFO | stdout | len(layers) 1 +2024-12-24 13:45:57 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:45:59 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:45:59 | INFO | stdout | len(layers) 1 +2024-12-24 13:45:59 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:52:19 | INFO | stdout | Keyboard interruption in main thread... closing server. +2024-12-24 13:52:20 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread +2024-12-24 13:52:20 | ERROR | stderr | time.sleep(0.1) +2024-12-24 13:52:20 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:52:20 | ERROR | stderr | +2024-12-24 13:52:20 | ERROR | stderr | During handling of the above exception, another exception occurred: +2024-12-24 13:52:20 | ERROR | stderr | +2024-12-24 13:52:20 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:52:20 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in +2024-12-24 13:52:20 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True) +2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch +2024-12-24 13:52:20 | ERROR | stderr | self.block_thread() +2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread +2024-12-24 13:52:20 | ERROR | stderr | self.server.close() +2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close +2024-12-24 13:52:20 | ERROR | stderr | self.thread.join(timeout=5) +2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join +2024-12-24 13:52:20 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0)) +2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock +2024-12-24 13:52:20 | ERROR | stderr | if lock.acquire(block, timeout): +2024-12-24 13:52:20 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:52:32 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. +2024-12-24 13:52:32 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) +2024-12-24 13:52:34 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:52:34 | ERROR | stderr | warnings.warn( +2024-12-24 13:52:34 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:52:34 | ERROR | stderr | warnings.warn( +2024-12-24 13:52:35 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860 +2024-12-24 13:52:42 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:52:42 | INFO | stdout | len(layers) 1 +2024-12-24 13:52:42 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:52:42 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:52:42 | INFO | stdout | len(layers) 1 +2024-12-24 13:52:42 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:52:44 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:52:44 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:52:44 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 13:52:44 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 13:52:44 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 13:52:44 | ERROR | stderr | result = await self.call_function( +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 13:52:44 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 13:52:44 | ERROR | stderr | return await anext(iterator) +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 13:52:44 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 13:52:44 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 13:52:44 | ERROR | stderr | return await future +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 13:52:44 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 13:52:44 | ERROR | stderr | return next(iterator) +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper +2024-12-24 13:52:44 | ERROR | stderr | response = next(iterator) +2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy +2024-12-24 13:52:44 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3) +2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony +2024-12-24 13:52:44 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run) +2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker +2024-12-24 13:52:44 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp() +2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp +2024-12-24 13:52:44 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) +2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client +2024-12-24 13:52:44 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password) +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect +2024-12-24 13:52:44 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port)) +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses +2024-12-24 13:52:44 | ERROR | stderr | addrinfos = socket.getaddrinfo( +2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo +2024-12-24 13:52:44 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags): +2024-12-24 13:52:44 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype +2024-12-24 13:52:44 | INFO | stdout | Rank +2024-12-24 13:53:06 | INFO | stdout | +2024-12-24 13:53:06 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3. +2024-12-24 13:53:06 | INFO | stdout | +2024-12-24 13:53:06 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: +2024-12-24 13:53:06 | INFO | stdout | +2024-12-24 13:53:06 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2024-12-24 13:53:06 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +2024-12-24 13:53:06 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio +2024-12-24 13:56:11 | INFO | stdout | Keyboard interruption in main thread... closing server. +2024-12-24 13:56:12 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread +2024-12-24 13:56:12 | ERROR | stderr | time.sleep(0.1) +2024-12-24 13:56:12 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:56:12 | ERROR | stderr | +2024-12-24 13:56:12 | ERROR | stderr | During handling of the above exception, another exception occurred: +2024-12-24 13:56:12 | ERROR | stderr | +2024-12-24 13:56:12 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:56:12 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in +2024-12-24 13:56:12 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True) +2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch +2024-12-24 13:56:12 | ERROR | stderr | self.block_thread() +2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread +2024-12-24 13:56:12 | ERROR | stderr | self.server.close() +2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close +2024-12-24 13:56:12 | ERROR | stderr | self.thread.join(timeout=5) +2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join +2024-12-24 13:56:12 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0)) +2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock +2024-12-24 13:56:12 | ERROR | stderr | if lock.acquire(block, timeout): +2024-12-24 13:56:12 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:56:24 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. +2024-12-24 13:56:24 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) +2024-12-24 13:56:26 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:56:26 | ERROR | stderr | warnings.warn( +2024-12-24 13:56:26 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:56:26 | ERROR | stderr | warnings.warn( +2024-12-24 13:56:26 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860 +2024-12-24 13:56:36 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:56:36 | INFO | stdout | len(layers) 1 +2024-12-24 13:56:36 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:56:36 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:56:36 | INFO | stdout | len(layers) 1 +2024-12-24 13:56:36 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:56:41 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:56:41 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:56:41 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 13:56:41 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 13:56:41 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 13:56:41 | ERROR | stderr | result = await self.call_function( +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 13:56:41 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 13:56:41 | ERROR | stderr | return await anext(iterator) +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 13:56:41 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 13:56:41 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 13:56:41 | ERROR | stderr | return await future +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 13:56:41 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 13:56:41 | ERROR | stderr | return next(iterator) +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper +2024-12-24 13:56:41 | ERROR | stderr | response = next(iterator) +2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy +2024-12-24 13:56:41 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3) +2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony +2024-12-24 13:56:41 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run) +2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker +2024-12-24 13:56:41 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp() +2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp +2024-12-24 13:56:41 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) +2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client +2024-12-24 13:56:41 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password) +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect +2024-12-24 13:56:41 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port)) +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses +2024-12-24 13:56:41 | ERROR | stderr | addrinfos = socket.getaddrinfo( +2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo +2024-12-24 13:56:41 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags): +2024-12-24 13:56:41 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype +2024-12-24 13:56:41 | INFO | stdout | Rank +2024-12-24 13:57:05 | INFO | stdout | +2024-12-24 13:57:05 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3. +2024-12-24 13:57:05 | INFO | stdout | +2024-12-24 13:57:05 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: +2024-12-24 13:57:05 | INFO | stdout | +2024-12-24 13:57:05 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2024-12-24 13:57:05 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +2024-12-24 13:57:05 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio +2024-12-24 13:58:10 | INFO | stdout | Keyboard interruption in main thread... closing server. +2024-12-24 13:58:10 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread +2024-12-24 13:58:10 | ERROR | stderr | time.sleep(0.1) +2024-12-24 13:58:10 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:58:10 | ERROR | stderr | +2024-12-24 13:58:10 | ERROR | stderr | During handling of the above exception, another exception occurred: +2024-12-24 13:58:10 | ERROR | stderr | +2024-12-24 13:58:10 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:58:10 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in +2024-12-24 13:58:10 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True) +2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch +2024-12-24 13:58:10 | ERROR | stderr | self.block_thread() +2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread +2024-12-24 13:58:10 | ERROR | stderr | self.server.close() +2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close +2024-12-24 13:58:10 | ERROR | stderr | self.thread.join(timeout=5) +2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join +2024-12-24 13:58:10 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0)) +2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock +2024-12-24 13:58:10 | ERROR | stderr | if lock.acquire(block, timeout): +2024-12-24 13:58:10 | ERROR | stderr | KeyboardInterrupt +2024-12-24 13:58:11 | ERROR | stderr | Exception ignored in: +2024-12-24 13:58:11 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:58:11 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1567, in _shutdown +2024-12-24 13:58:11 | ERROR | stderr | lock.acquire() +2024-12-24 13:58:11 | ERROR | stderr | KeyboardInterrupt: +2024-12-24 13:58:20 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. +2024-12-24 13:58:20 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) +2024-12-24 13:58:21 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:58:21 | ERROR | stderr | warnings.warn( +2024-12-24 13:58:21 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(, >), received 11. +2024-12-24 13:58:21 | ERROR | stderr | warnings.warn( +2024-12-24 13:58:22 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860 +2024-12-24 13:58:32 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:58:32 | INFO | stdout | len(layers) 1 +2024-12-24 13:58:32 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:58:33 | INFO | stdout | background.shape (600, 600, 4) +2024-12-24 13:58:33 | INFO | stdout | len(layers) 1 +2024-12-24 13:58:33 | INFO | stdout | composite.shape (600, 600, 4) +2024-12-24 13:58:37 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2024-12-24 13:58:37 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2024-12-24 13:58:37 | INFO | stdout | [0] +2024-12-24 13:58:37 | INFO | stdout | ['replicate_SDXL_text2image'] +2024-12-24 13:58:37 | ERROR | stderr | Traceback (most recent call last): +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events +2024-12-24 13:58:37 | ERROR | stderr | response = await route_utils.call_process_api( +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api +2024-12-24 13:58:37 | ERROR | stderr | output = await app.get_blocks().process_api( +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api +2024-12-24 13:58:37 | ERROR | stderr | result = await self.call_function( +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function +2024-12-24 13:58:37 | ERROR | stderr | prediction = await utils.async_iteration(iterator) +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration +2024-12-24 13:58:37 | ERROR | stderr | return await anext(iterator) +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__ +2024-12-24 13:58:37 | ERROR | stderr | return await anyio.to_thread.run_sync( +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync +2024-12-24 13:58:37 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread( +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread +2024-12-24 13:58:37 | ERROR | stderr | return await future +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run +2024-12-24 13:58:37 | ERROR | stderr | result = context.run(func, *args) +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async +2024-12-24 13:58:37 | ERROR | stderr | return next(iterator) +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper +2024-12-24 13:58:37 | ERROR | stderr | response = next(iterator) +2024-12-24 13:58:37 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy +2024-12-24 13:58:37 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3) +2024-12-24 13:58:37 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 104, in generate_image_b2i_parallel_anony +2024-12-24 13:58:37 | ERROR | stderr | results = [future.result() for future in futures] +2024-12-24 13:58:37 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 104, in +2024-12-24 13:58:37 | ERROR | stderr | results = [future.result() for future in futures] +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/concurrent/futures/_base.py", line 451, in result +2024-12-24 13:58:37 | ERROR | stderr | return self.__get_result() +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result +2024-12-24 13:58:37 | ERROR | stderr | raise self._exception +2024-12-24 13:58:37 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/concurrent/futures/thread.py", line 58, in run +2024-12-24 13:58:37 | ERROR | stderr | result = self.fn(*self.args, **self.kwargs) +2024-12-24 13:58:37 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 87, in generate_image_b2i +2024-12-24 13:58:37 | ERROR | stderr | return result +2024-12-24 13:58:37 | ERROR | stderr | UnboundLocalError: local variable 'result' referenced before assignment +2024-12-24 13:58:37 | INFO | stdout | Rank +2024-12-24 13:58:53 | INFO | stdout | +2024-12-24 13:58:53 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3. +2024-12-24 13:58:53 | INFO | stdout | +2024-12-24 13:58:53 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: +2024-12-24 13:58:53 | INFO | stdout | +2024-12-24 13:58:53 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2024-12-24 13:58:53 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +2024-12-24 13:58:53 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio diff --git a/ksort-logs/vote_log/gr_web_image_generation_multi.log b/ksort-logs/vote_log/gr_web_image_generation_multi.log new file mode 100644 index 0000000000000000000000000000000000000000..0fc298724488fc0a7d606979451cea8673f8e0af --- /dev/null +++ b/ksort-logs/vote_log/gr_web_image_generation_multi.log @@ -0,0 +1,6 @@ +2024-12-24 12:55:47 | INFO | gradio_web_server_image_generation_multi | generate. ip: None +2024-12-24 13:43:17 | INFO | gradio_web_server_image_generation_multi | generate. ip: None +2024-12-24 13:44:08 | INFO | gradio_web_server_image_generation_multi | generate. ip: None +2024-12-24 13:52:44 | INFO | gradio_web_server_image_generation_multi | generate. ip: None +2024-12-24 13:56:41 | INFO | gradio_web_server_image_generation_multi | generate. ip: None +2024-12-24 13:58:37 | INFO | gradio_web_server_image_generation_multi | generate. ip: None diff --git a/ksort-logs/vote_log/gr_web_video_generation.log b/ksort-logs/vote_log/gr_web_video_generation.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ksort-logs/vote_log/gr_web_video_generation_multi.log b/ksort-logs/vote_log/gr_web_video_generation_multi.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/__pycache__/__init__.cpython-310.pyc b/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f5b229a16bb4eff72e4b37b394efc30b0125702 Binary files /dev/null and b/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/__pycache__/__init__.cpython-312.pyc b/model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad06f93d72ebfa3a41679b658acf6eae226e0570 Binary files /dev/null and b/model/__pycache__/__init__.cpython-312.pyc differ diff --git a/model/__pycache__/__init__.cpython-39.pyc b/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03b61469632c10ae29ea3ddbe46eab2dc30e694c Binary files /dev/null and b/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/model/__pycache__/matchmaker.cpython-310.pyc b/model/__pycache__/matchmaker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da5ce11c0fe69fa1808055e98bbae9affbfff2dd Binary files /dev/null and b/model/__pycache__/matchmaker.cpython-310.pyc differ diff --git a/model/__pycache__/model_manager.cpython-310.pyc b/model/__pycache__/model_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a9a7f465d4fb33aff3fd455185a90194b70a3d4 Binary files /dev/null and b/model/__pycache__/model_manager.cpython-310.pyc differ diff --git a/model/__pycache__/model_registry.cpython-310.pyc b/model/__pycache__/model_registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ba5e7e48835a1824bce700df61d6324600ee45b Binary files /dev/null and b/model/__pycache__/model_registry.cpython-310.pyc differ diff --git a/model/__pycache__/model_registry.cpython-312.pyc b/model/__pycache__/model_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c08109373e8997d9111a282af90d660609d0915d Binary files /dev/null and b/model/__pycache__/model_registry.cpython-312.pyc differ diff --git a/model/__pycache__/model_registry.cpython-39.pyc b/model/__pycache__/model_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df74ac1274e8d05b1d6730593174da3bb162940d Binary files /dev/null and b/model/__pycache__/model_registry.cpython-39.pyc differ diff --git a/model/matchmaker.py b/model/matchmaker.py new file mode 100644 index 0000000000000000000000000000000000000000..1218181fcc40f5ff906526a6582e9223f9ac11d3 --- /dev/null +++ b/model/matchmaker.py @@ -0,0 +1,126 @@ +import numpy as np +import json +from trueskill import TrueSkill +import paramiko +import io, os +import sys +import random + +sys.path.append('../') +from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL +trueskill_env = TrueSkill() + +ssh_matchmaker_client = None +sftp_matchmaker_client = None + +def create_ssh_matchmaker_client(server, port, user, password): + global ssh_matchmaker_client, sftp_matchmaker_client + ssh_matchmaker_client = paramiko.SSHClient() + ssh_matchmaker_client.load_system_host_keys() + ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh_matchmaker_client.connect(server, port, user, password) + + transport = ssh_matchmaker_client.get_transport() + transport.set_keepalive(60) + + sftp_matchmaker_client = ssh_matchmaker_client.open_sftp() + + +def is_connected(): + global ssh_matchmaker_client, sftp_matchmaker_client + if ssh_matchmaker_client is None or sftp_matchmaker_client is None: + return False + if not ssh_matchmaker_client.get_transport().is_active(): + return False + try: + sftp_matchmaker_client.listdir('.') + except Exception as e: + print(f"Error checking SFTP connection: {e}") + return False + return True + + +def ucb_score(trueskill_diff, t, n): + exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5)) + ucb = -trueskill_diff + 1.0 * exploration_term + return ucb + + +def update_trueskill(ratings, ranks): + new_ratings = trueskill_env.rate(ratings, ranks) + return new_ratings + + +def serialize_rating(rating): + return {'mu': rating.mu, 'sigma': rating.sigma} + + +def deserialize_rating(rating_dict): + return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma']) + + +def save_json_via_sftp(ratings, comparison_counts, total_comparisons): + global sftp_matchmaker_client + if not is_connected(): + create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + data = { + 'ratings': [serialize_rating(r) for r in ratings], + 'comparison_counts': comparison_counts.tolist(), + 'total_comparisons': total_comparisons + } + json_data = json.dumps(data) + with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f: + f.write(json_data) + + +def load_json_via_sftp(): + global sftp_matchmaker_client + if not is_connected(): + create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + with sftp_matchmaker_client.open(SSH_SKILL, 'r') as f: + data = json.load(f) + ratings = [deserialize_rating(r) for r in data['ratings']] + comparison_counts = np.array(data['comparison_counts']) + total_comparisons = data['total_comparisons'] + return ratings, comparison_counts, total_comparisons + + +class RunningPivot(object): + running_pivot = [] + + +def matchmaker(num_players, k_group=4, not_run=[]): + trueskill_env = TrueSkill() + + ratings, comparison_counts, total_comparisons = load_json_via_sftp() + + ratings = ratings[:num_players] + comparison_counts = comparison_counts[:num_players, :num_players] + + # Randomly select a player + # selected_player = np.random.randint(0, num_players) + comparison_counts[RunningPivot.running_pivot, :] = float('inf') + comparison_counts[not_run, :] = float('inf') + selected_player = np.argmin(comparison_counts.sum(axis=1)) + + RunningPivot.running_pivot.append(selected_player) + RunningPivot.running_pivot = RunningPivot.running_pivot[-5:] + print(RunningPivot.running_pivot) + + selected_trueskill_score = trueskill_env.expose(ratings[selected_player]) + trueskill_scores = np.array([trueskill_env.expose(p) for p in ratings]) + trueskill_diff = np.abs(trueskill_scores - selected_trueskill_score) + n = comparison_counts[selected_player] + ucb_scores = ucb_score(trueskill_diff, total_comparisons, n) + + # Exclude self, select opponent with highest UCB score + ucb_scores[selected_player] = -float('inf') + ucb_scores[not_run] = -float('inf') + opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist() + + # Group players + model_ids = [selected_player] + opponents + + random.shuffle(model_ids) + + return model_ids diff --git a/model/matchmaker_video.py b/model/matchmaker_video.py new file mode 100644 index 0000000000000000000000000000000000000000..bb97e4cae98ba8805e9c4e405c49c55f18d6df9d --- /dev/null +++ b/model/matchmaker_video.py @@ -0,0 +1,136 @@ +import numpy as np +import json +from trueskill import TrueSkill +import paramiko +import io, os +import sys +import random + +sys.path.append('../') +from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_VIDEO_SKILL +trueskill_env = TrueSkill() + +ssh_matchmaker_client = None +sftp_matchmaker_client = None + + +def create_ssh_matchmaker_client(server, port, user, password): + global ssh_matchmaker_client, sftp_matchmaker_client + ssh_matchmaker_client = paramiko.SSHClient() + ssh_matchmaker_client.load_system_host_keys() + ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh_matchmaker_client.connect(server, port, user, password) + + transport = ssh_matchmaker_client.get_transport() + transport.set_keepalive(60) + + sftp_matchmaker_client = ssh_matchmaker_client.open_sftp() + + +def is_connected(): + global ssh_matchmaker_client, sftp_matchmaker_client + if ssh_matchmaker_client is None or sftp_matchmaker_client is None: + return False + if not ssh_matchmaker_client.get_transport().is_active(): + return False + try: + sftp_matchmaker_client.listdir('.') + except Exception as e: + print(f"Error checking SFTP connection: {e}") + return False + return True + + +def ucb_score(trueskill_diff, t, n): + exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5)) + ucb = -trueskill_diff + 1.0 * exploration_term + return ucb + + +def update_trueskill(ratings, ranks): + new_ratings = trueskill_env.rate(ratings, ranks) + return new_ratings + + +def serialize_rating(rating): + return {'mu': rating.mu, 'sigma': rating.sigma} + + +def deserialize_rating(rating_dict): + return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma']) + + +def save_json_via_sftp(ratings, comparison_counts, total_comparisons): + global sftp_matchmaker_client + if not is_connected(): + create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + data = { + 'ratings': [serialize_rating(r) for r in ratings], + 'comparison_counts': comparison_counts.tolist(), + 'total_comparisons': total_comparisons + } + json_data = json.dumps(data) + with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f: + f.write(json_data) + + +def load_json_via_sftp(): + global sftp_matchmaker_client + if not is_connected(): + create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'r') as f: + data = json.load(f) + ratings = [deserialize_rating(r) for r in data['ratings']] + comparison_counts = np.array(data['comparison_counts']) + total_comparisons = data['total_comparisons'] + return ratings, comparison_counts, total_comparisons + + +def matchmaker_video(num_players, k_group=4): + trueskill_env = TrueSkill() + + ratings, comparison_counts, total_comparisons = load_json_via_sftp() + + ratings = ratings[:num_players] + comparison_counts = comparison_counts[:num_players, :num_players] + + selected_player = np.argmin(comparison_counts.sum(axis=1)) + + selected_trueskill_score = trueskill_env.expose(ratings[selected_player]) + trueskill_scores = np.array([trueskill_env.expose(p) for p in ratings]) + trueskill_diff = np.abs(trueskill_scores - selected_trueskill_score) + n = comparison_counts[selected_player] + ucb_scores = ucb_score(trueskill_diff, total_comparisons, n) + + # Exclude self, select opponent with highest UCB score + ucb_scores[selected_player] = -float('inf') + + excluded_players_1 = [7, 10] + excluded_players_2 = [6, 8, 9] + excluded_players = excluded_players_1 + excluded_players_2 + if selected_player in excluded_players_1: + for player in excluded_players: + ucb_scores[player] = -float('inf') + if selected_player in excluded_players_2: + for player in excluded_players_1: + ucb_scores[player] = -float('inf') + else: + excluded_ucb_scores = {player: ucb_scores[player] for player in excluded_players} + max_player = max(excluded_ucb_scores, key=excluded_ucb_scores.get) + if max_player in excluded_players_1: + for player in excluded_players: + if player != max_player: + ucb_scores[player] = -float('inf') + else: + for player in excluded_players_1: + ucb_scores[player] = -float('inf') + + + opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist() + + # Group players + model_ids = [selected_player] + opponents + + random.shuffle(model_ids) + + return model_ids diff --git a/model/model_manager.py b/model/model_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ad039fd5523a6eff5200c2d973c9ab2157f4a816 --- /dev/null +++ b/model/model_manager.py @@ -0,0 +1,239 @@ +import concurrent.futures +import random +import gradio as gr +import requests, os +import io, base64, json +import spaces +import torch +from PIL import Image +from openai import OpenAI +from .models import IMAGE_GENERATION_MODELS, VIDEO_GENERATION_MODELS, B2I_MODELS, load_pipeline +from serve.upload import get_random_mscoco_prompt, get_random_video_prompt, get_ssh_random_video_prompt, get_ssh_random_image_prompt +from serve.constants import SSH_CACHE_OPENSOURCE, SSH_CACHE_ADVANCE, SSH_CACHE_PIKA, SSH_CACHE_SORA, SSH_CACHE_IMAGE + + +class ModelManager: + def __init__(self): + self.model_ig_list = IMAGE_GENERATION_MODELS + self.model_ie_list = [] #IMAGE_EDITION_MODELS + self.model_vg_list = VIDEO_GENERATION_MODELS + self.model_b2i_list = B2I_MODELS + self.loaded_models = {} + + def load_model_pipe(self, model_name): + if not model_name in self.loaded_models: + pipe = load_pipeline(model_name) + self.loaded_models[model_name] = pipe + else: + pipe = self.loaded_models[model_name] + return pipe + + @spaces.GPU(duration=120) + def generate_image_ig(self, prompt, model_name): + pipe = self.load_model_pipe(model_name) + if 'Stable-cascade' not in model_name: + result = pipe(prompt=prompt).images[0] + else: + prior, decoder = pipe + prior.enable_model_cpu_offload() + prior_output = prior( + prompt=prompt, + height=512, + width=512, + negative_prompt='', + guidance_scale=4.0, + num_images_per_prompt=1, + num_inference_steps=20 + ) + decoder.enable_model_cpu_offload() + result = decoder( + image_embeddings=prior_output.image_embeddings.to(torch.float16), + prompt=prompt, + negative_prompt='', + guidance_scale=0.0, + output_type="pil", + num_inference_steps=10 + ).images[0] + return result + + def generate_image_ig_api(self, prompt, model_name): + pipe = self.load_model_pipe(model_name) + result = pipe(prompt=prompt) + return result + + def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D): + if model_A == "" and model_B == "" and model_C == "" and model_D == "": + from .matchmaker import matchmaker + not_run = [20,21,22, 25,26, 30] #12,13,14,15,16,17,18,19,20,21,22, #23,24, + model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run) + print(model_ids) + model_names = [self.model_ig_list[i] for i in model_ids] + print(model_names) + else: + model_names = [model_A, model_B, model_C, model_D] + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface") + else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names] + results = [future.result() for future in futures] + + return results[0], results[1], results[2], results[3], \ + model_names[0], model_names[1], model_names[2], model_names[3] + + def generate_image_b2i(self, prompt, grounding_instruction, bbox, model_name): + pipe = self.load_model_pipe(model_name) + if model_name == "local_MIGC_b2i": + from model_bbox.MIGC.inference_single_image import inference_image + result = inference_image(pipe, prompt, grounding_instruction, bbox) + elif model_name == "huggingface_ReCo_b2i": + from model_bbox.ReCo.inference import inference_image + result = inference_image(pipe, prompt, grounding_instruction, bbox) + return result + + + def generate_image_b2i_parallel_anony(self, prompt, grounding_instruction, bbox, model_A, model_B, model_C, model_D): + if model_A == "" and model_B == "" and model_C == "" and model_D == "": + from .matchmaker import matchmaker + not_run = [] #12,13,14,15,16,17,18,19,20,21,22, #23,24, + # model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run) + model_ids = [0, 1] + print(model_ids) + model_names = [self.model_b2i_list[i] for i in model_ids] + print(model_names) + else: + model_names = [model_A, model_B, model_C, model_D] + + from concurrent.futures import ProcessPoolExecutor + with ProcessPoolExecutor() as executor: + futures = [executor.submit(self.generate_image_b2i, prompt, grounding_instruction, bbox, model) + for model in model_names] + results = [future.result() for future in futures] + + # with concurrent.futures.ThreadPoolExecutor() as executor: + # futures = [executor.submit(self.generate_image_b2i, prompt, grounding_instruction, bbox, model) for model in model_names] + # results = [future.result() for future in futures] + + blank_image = None + final_results = [] + for i in range(4): + if i < len(model_ids): + # 如果是有效模型,返回相应的生成结果 + final_results.append(results[i]) + else: + # 如果没有生成结果,则返回空白图像 + final_results.append(blank_image) + final_model_names = [] + for i in range(4): + if i < len(model_ids): + final_model_names.append(model_names[i]) + else: + final_model_names.append("") + + return final_results[0], final_results[1], final_results[2], final_results[3], \ + final_model_names[0], final_model_names[1], final_model_names[2], final_model_names[3] + + def generate_image_ig_cache_anony(self, model_A, model_B, model_C, model_D): + if model_A == "" and model_B == "" and model_C == "" and model_D == "": + from .matchmaker import matchmaker + not_run = [20,21,22] + model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run) + print(model_ids) + model_names = [self.model_ig_list[i] for i in model_ids] + print(model_names) + else: + model_names = [model_A, model_B, model_C, model_D] + + root_dir = SSH_CACHE_IMAGE + local_dir = "./cache_image" + if not os.path.exists(local_dir): + os.makedirs(local_dir) + prompt, results = get_ssh_random_image_prompt(root_dir, local_dir, model_names) + + return results[0], results[1], results[2], results[3], \ + model_names[0], model_names[1], model_names[2], model_names[3], prompt + + def generate_video_vg_parallel_anony(self, model_A, model_B, model_C, model_D): + if model_A == "" and model_B == "" and model_C == "" and model_D == "": + # model_names = random.sample([model for model in self.model_vg_list], 4) + + from .matchmaker_video import matchmaker_video + model_ids = matchmaker_video(num_players=len(self.model_vg_list)) + print(model_ids) + model_names = [self.model_vg_list[i] for i in model_ids] + print(model_names) + else: + model_names = [model_A, model_B, model_C, model_D] + + root_dir = SSH_CACHE_OPENSOURCE + for name in model_names: + if "Runway-Gen3" in name or "Runway-Gen2" in name or "Pika-v1.0" in name: + root_dir = SSH_CACHE_ADVANCE + elif "Pika-beta" in name: + root_dir = SSH_CACHE_PIKA + elif "Sora" in name and "OpenSora" not in name: + root_dir = SSH_CACHE_SORA + + local_dir = "./cache_video" + if not os.path.exists(local_dir): + os.makedirs(local_dir) + prompt, results = get_ssh_random_video_prompt(root_dir, local_dir, model_names) + cache_dir = local_dir + + return results[0], results[1], results[2], results[3], \ + model_names[0], model_names[1], model_names[2], model_names[3], prompt, cache_dir + + def generate_image_ig_museum_parallel_anony(self, model_A, model_B, model_C, model_D): + if model_A == "" and model_B == "" and model_C == "" and model_D == "": + # model_names = random.sample([model for model in self.model_ig_list], 4) + + from .matchmaker import matchmaker + model_ids = matchmaker(num_players=len(self.model_ig_list)) + print(model_ids) + model_names = [self.model_ig_list[i] for i in model_ids] + print(model_names) + else: + model_names = [model_A, model_B, model_C, model_D] + + prompt = get_random_mscoco_prompt() + print(prompt) + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface") + else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names] + results = [future.result() for future in futures] + + return results[0], results[1], results[2], results[3], \ + model_names[0], model_names[1], model_names[2], model_names[3], prompt + + def generate_image_ig_parallel(self, prompt, model_A, model_B): + model_names = [model_A, model_B] + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("imagenhub") + else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names] + results = [future.result() for future in futures] + return results[0], results[1] + + @spaces.GPU(duration=200) + def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name): + pipe = self.load_model_pipe(model_name) + result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct) + return result + + def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B): + model_names = [model_A, model_B] + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image, + model) for model in model_names] + results = [future.result() for future in futures] + return results[0], results[1] + + def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B): + if model_A == "" and model_B == "": + model_names = random.sample([model for model in self.model_ie_list], 2) + else: + model_names = [model_A, model_B] + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image, model) for model in model_names] + results = [future.result() for future in futures] + return results[0], results[1], model_names[0], model_names[1] diff --git a/model/model_registry.py b/model/model_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d9f0c2a8b84e851d59581b709ffe2ba99050bc --- /dev/null +++ b/model/model_registry.py @@ -0,0 +1,70 @@ +from collections import namedtuple +from typing import List + +ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"]) +model_info = {} + +def register_model_info( + full_names: List[str], simple_name: str, link: str, description: str +): + info = ModelInfo(simple_name, link, description) + + for full_name in full_names: + model_info[full_name] = info + +def get_model_info(name: str) -> ModelInfo: + if name in model_info: + return model_info[name] + else: + # To fix this, please use `register_model_info` to register your model + return ModelInfo( + name, "", "Register the description at fastchat/model/model_registry.py" + ) + +def get_model_description_md(model_list): + model_description_md = """ +| | | | | | | | | | | | +| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | +""" + ct = 0 + visited = set() + for i, name in enumerate(model_list): + model_source, model_name, model_type = name.split("_") + minfo = get_model_info(model_name) + if minfo.simple_name in visited: + continue + visited.add(minfo.simple_name) + # one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}" + one_model_md = f"{minfo.simple_name}" + + if ct % 11 == 0: + model_description_md += "|" + model_description_md += f" {one_model_md} |" + if ct % 11 == 10: + model_description_md += "\n" + ct += 1 + return model_description_md + +def get_video_model_description_md(model_list): + model_description_md = """ +| | | | | | | +| ---- | ---- | ---- | ---- | ---- | ---- | +""" + ct = 0 + visited = set() + for i, name in enumerate(model_list): + model_source, model_name, model_type = name.split("_") + minfo = get_model_info(model_name) + if minfo.simple_name in visited: + continue + visited.add(minfo.simple_name) + # one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}" + one_model_md = f"{minfo.simple_name}" + + if ct % 7 == 0: + model_description_md += "|" + model_description_md += f" {one_model_md} |" + if ct % 7 == 6: + model_description_md += "\n" + ct += 1 + return model_description_md diff --git a/model/models/__init__.py b/model/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a939eaa60adbce280f6382be0bb8e19c34da24d3 --- /dev/null +++ b/model/models/__init__.py @@ -0,0 +1,83 @@ +from .huggingface_models import load_huggingface_model +from .replicate_api_models import load_replicate_model +from .openai_api_models import load_openai_model +from .other_api_models import load_other_model +from .local_models import load_local_model + + +IMAGE_GENERATION_MODELS = [ + 'replicate_SDXL_text2image', + 'replicate_SD-v3.0_text2image', + 'replicate_SD-v2.1_text2image', + 'replicate_SD-v1.5_text2image', + 'replicate_SDXL-Lightning_text2image', + 'replicate_Kandinsky-v2.0_text2image', + 'replicate_Kandinsky-v2.2_text2image', + 'replicate_Proteus-v0.2_text2image', + 'replicate_Playground-v2.0_text2image', + 'replicate_Playground-v2.5_text2image', + 'replicate_Dreamshaper-xl-turbo_text2image', + 'replicate_SDXL-Deepcache_text2image', + 'replicate_Openjourney-v4_text2image', + 'replicate_LCM-v1.5_text2image', + 'replicate_Realvisxl-v3.0_text2image', + 'replicate_Realvisxl-v2.0_text2image', + 'replicate_Pixart-Sigma_text2image', + 'replicate_SSD-1b_text2image', + 'replicate_Open-Dalle-v1.1_text2image', + 'replicate_Deepfloyd-IF_text2image', + 'huggingface_SD-turbo_text2image', + 'huggingface_SDXL-turbo_text2image', + 'huggingface_Stable-cascade_text2image', + 'openai_Dalle-2_text2image', + 'openai_Dalle-3_text2image', + 'other_Midjourney-v6.0_text2image', + 'other_Midjourney-v5.0_text2image', + "replicate_FLUX.1-schnell_text2image", + "replicate_FLUX.1-pro_text2image", + "replicate_FLUX.1-dev_text2image", + 'other_Meissonic_text2image', + "replicate_FLUX-1.1-pro_text2image", + 'replicate_SD-v3.5-large_text2image', + 'replicate_SD-v3.5-large-turbo_text2image', + ] + +VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video', + 'replicate_Animate-Diff_text2video', + 'replicate_OpenSora_text2video', + 'replicate_LaVie_text2video', + 'replicate_VideoCrafter2_text2video', + 'replicate_Stable-Video-Diffusion_text2video', + 'other_Runway-Gen3_text2video', + 'other_Pika-beta_text2video', + 'other_Pika-v1.0_text2video', + 'other_Runway-Gen2_text2video', + 'other_Sora_text2video', + 'replicate_Cogvideox-5b_text2video', + 'other_KLing-v1.0_text2video', + ] + +B2I_MODELS = ['local_MIGC_b2i', 'huggingface_ReCo_b2i'] + + +def load_pipeline(model_name): + """ + Load a model pipeline based on the model name + Args: + model_name (str): The name of the model to load, should be of the form {source}_{name}_{type} + """ + model_source, model_name, model_type = model_name.split("_") + + if model_source == "replicate": + pipe = load_replicate_model(model_name, model_type) + elif model_source == "huggingface": + pipe = load_huggingface_model(model_name, model_type) + elif model_source == "openai": + pipe = load_openai_model(model_name, model_type) + elif model_source == "other": + pipe = load_other_model(model_name, model_type) + elif model_source == "local": + pipe = load_local_model(model_name, model_type) + else: + raise ValueError(f"Model source {model_source} not supported") + return pipe \ No newline at end of file diff --git a/model/models/__pycache__/__init__.cpython-310.pyc b/model/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c9884477fbd84ec06e7362b6619129524f6fe24 Binary files /dev/null and b/model/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/models/__pycache__/huggingface_models.cpython-310.pyc b/model/models/__pycache__/huggingface_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e310c1dcb478c0d7559412e869cd6fb190c3178 Binary files /dev/null and b/model/models/__pycache__/huggingface_models.cpython-310.pyc differ diff --git a/model/models/__pycache__/local_models.cpython-310.pyc b/model/models/__pycache__/local_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..873fc19038acd0b47f1651b37223164463461e8e Binary files /dev/null and b/model/models/__pycache__/local_models.cpython-310.pyc differ diff --git a/model/models/__pycache__/openai_api_models.cpython-310.pyc b/model/models/__pycache__/openai_api_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..768e426dea193f07f67ab47f05556f94ab77bfbf Binary files /dev/null and b/model/models/__pycache__/openai_api_models.cpython-310.pyc differ diff --git a/model/models/__pycache__/other_api_models.cpython-310.pyc b/model/models/__pycache__/other_api_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb494165a0682b488938494af5ab0cd1f65eca1d Binary files /dev/null and b/model/models/__pycache__/other_api_models.cpython-310.pyc differ diff --git a/model/models/__pycache__/replicate_api_models.cpython-310.pyc b/model/models/__pycache__/replicate_api_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcb5a7debe6a4ebce45fdac6d0f7080bacc08e38 Binary files /dev/null and b/model/models/__pycache__/replicate_api_models.cpython-310.pyc differ diff --git a/model/models/huggingface_models.py b/model/models/huggingface_models.py new file mode 100644 index 0000000000000000000000000000000000000000..e6fb5a99675ede1c9d1119f67174248b899370fa --- /dev/null +++ b/model/models/huggingface_models.py @@ -0,0 +1,65 @@ +from diffusers import DiffusionPipeline +from diffusers import AutoPipelineForText2Image +from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline +from diffusers import StableDiffusionPipeline +import torch +import os + + +def load_huggingface_model(model_name, model_type): + if model_name == "SD-turbo": + pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16") + pipe = pipe.to("cuda") + elif model_name == "SDXL-turbo": + pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16") + pipe = pipe.to("cuda") + elif model_name == "Stable-cascade": + prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16) + decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16) + pipe = [prior, decoder] + elif model_name == "ReCo": + path = '/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/snapshots/11a062da5a0a84501047cb19e113f520eb610415' if os.path.isdir('/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/snapshots/11a062da5a0a84501047cb19e113f520eb610415') else "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionPipeline.from_pretrained(path ,torch_dtype=torch.float16) + pipe = pipe.to("cuda") + else: + raise NotImplementedError + # if model_name == "SD-turbo": + # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo") + # elif model_name == "SDXL-turbo": + # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo") + # else: + # raise NotImplementedError + # pipe = pipe.to("cpu") + return pipe + + +if __name__ == "__main__": + # for name in ["SD-turbo", "SDXL-turbo"]: #"SD-turbo", "SDXL-turbo" + # pipe = load_huggingface_model(name, "text2image") + + # for name in ["IF-I-XL-v1.0"]: + # pipe = load_huggingface_model(name, 'text2image') + # pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) + + prompt = 'draw a tiger' + pipe = load_huggingface_model('Stable-cascade', "text2image") + prior, decoder = pipe + prior.enable_model_cpu_offload() + prior_output = prior( + prompt=prompt, + height=512, + width=512, + negative_prompt='', + guidance_scale=4.0, + num_images_per_prompt=1, + num_inference_steps=20 + ) + decoder.enable_model_cpu_offload() + result = decoder( + image_embeddings=prior_output.image_embeddings.to(torch.float16), + prompt=prompt, + negative_prompt='', + guidance_scale=0.0, + output_type="pil", + num_inference_steps=10 + ).images[0] \ No newline at end of file diff --git a/model/models/local_models.py b/model/models/local_models.py new file mode 100644 index 0000000000000000000000000000000000000000..f0b245e7b5f88e23091d9970b8f7071eb47254fc --- /dev/null +++ b/model/models/local_models.py @@ -0,0 +1,16 @@ +import os +import sys + +migc_path = os.path.dirname(os.path.abspath(__file__)) +print(migc_path) +if migc_path not in sys.path: + sys.path.append(migc_path) + +from model_bbox.MIGC.inference_single_image import MIGC_Pipe + +def load_local_model(model_name, model_type): + if model_name == "MIGC": + pipe = MIGC_Pipe() + else: + raise NotImplementedError + return pipe \ No newline at end of file diff --git a/model/models/openai_api_models.py b/model/models/openai_api_models.py new file mode 100644 index 0000000000000000000000000000000000000000..b440aca5cfb7d31de7634ec827330d4cb1cb37bf --- /dev/null +++ b/model/models/openai_api_models.py @@ -0,0 +1,57 @@ +from openai import OpenAI +from PIL import Image +import requests +import io +import os +import base64 + + +class OpenaiModel(): + def __init__(self, model_name, model_type): + self.model_name = model_name + self.model_type = model_type + + def __call__(self, *args, **kwargs): + if self.model_type == "text2image": + assert "prompt" in kwargs, "prompt is required for text2image model" + + client = OpenAI() + + if 'Dalle-3' in self.model_name: + client = OpenAI() + response = client.images.generate( + model="dall-e-3", + prompt=kwargs["prompt"], + size="1024x1024", + quality="standard", + n=1, + ) + elif 'Dalle-2' in self.model_name: + client = OpenAI() + response = client.images.generate( + model="dall-e-2", + prompt=kwargs["prompt"], + size="512x512", + quality="standard", + n=1, + ) + else: + raise NotImplementedError + + result_url = response.data[0].url + response = requests.get(result_url) + result = Image.open(io.BytesIO(response.content)) + return result + else: + raise ValueError("model_type must be text2image or image2image") + + +def load_openai_model(model_name, model_type): + return OpenaiModel(model_name, model_type) + + +if __name__ == "__main__": + pipe = load_openai_model('Dalle-3', 'text2image') + result = pipe(prompt='draw a tiger') + print(result) + diff --git a/model/models/other_api_models.py b/model/models/other_api_models.py new file mode 100644 index 0000000000000000000000000000000000000000..363e9dc148c94c0da77bb326ee1bdce05b8adb0c --- /dev/null +++ b/model/models/other_api_models.py @@ -0,0 +1,91 @@ +import requests +import json +import os +from PIL import Image +import io, time + + +class OtherModel(): + def __init__(self, model_name, model_type): + self.model_name = model_name + self.model_type = model_type + self.image_url = "https://www.xdai.online/mj/submit/imagine" + self.key = os.environ.get('MIDJOURNEY_KEY') + self.get_image_url = "https://www.xdai.online/mj/image/" + self.repeat_num = 5 + + def __call__(self, *args, **kwargs): + if self.model_type == "text2image": + assert "prompt" in kwargs, "prompt is required for text2image model" + if self.model_name == "Midjourney-v6.0": + data = { + "base64Array": [], + "notifyHook": "", + "prompt": "{} --v 6.0".format(kwargs["prompt"]), + "state": "", + "botType": "MID_JOURNEY", + } + elif self.model_name == "Midjourney-v5.0": + data = { + "base64Array": [], + "notifyHook": "", + "prompt": "{} --v 5.0".format(kwargs["prompt"]), + "state": "", + "botType": "MID_JOURNEY", + } + else: + raise NotImplementedError + + headers = { + "Authorization": "Bearer {}".format(self.key), + "Content-Type": "application/json" + } + while 1: + response = requests.post(self.image_url, data=json.dumps(data), headers=headers) + if response.status_code == 200: + print("Submit success!") + response_json = json.loads(response.content.decode('utf-8')) + img_id = response_json["result"] + result_url = self.get_image_url + img_id + print(result_url) + self.repeat_num = 800 + while 1: + time.sleep(1) + img_response = requests.get(result_url) + if img_response.status_code == 200: + result = Image.open(io.BytesIO(img_response.content)) + width, height = result.size + new_width = width // 2 + new_height = height // 2 + result = result.crop((0, 0, new_width, new_height)) + self.repeat_num = 5 + return result + else: + self.repeat_num = self.repeat_num - 1 + if self.repeat_num == 0: + raise ValueError("Image request failed.") + continue + + else: + self.repeat_num = self.repeat_num - 1 + if self.repeat_num == 0: + raise ValueError("API request failed.") + continue + if self.model_type == "text2video": + assert "prompt" in kwargs, "prompt is required for text2video model" + + else: + raise ValueError("model_type must be text2image") + + +def load_other_model(model_name, model_type): + return OtherModel(model_name, model_type) + +if __name__ == "__main__": + import http.client + import json + + pipe = load_other_model("Midjourney-v5.0", "text2image") + result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow") + print(result) + exit() diff --git a/model/models/replicate_api_models.py b/model/models/replicate_api_models.py new file mode 100644 index 0000000000000000000000000000000000000000..b0cb5139024ee35766cf1bf005a82ede37343a18 --- /dev/null +++ b/model/models/replicate_api_models.py @@ -0,0 +1,195 @@ +import replicate +from PIL import Image +import requests +import io +import os +import base64 + +Replicate_MODEl_NAME_MAP = { + "SDXL": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc", + "SD-v3.0": "stability-ai/stable-diffusion-3", + "SD-v2.1": "stability-ai/stable-diffusion:ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4", + "SD-v1.5": "stability-ai/stable-diffusion:b3d14e1cd1f9470bbb0bb68cac48e5f483e5be309551992cc33dc30654a82bb7", + "SDXL-Lightning": "bytedance/sdxl-lightning-4step:5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f", + "Kandinsky-v2.0": "ai-forever/kandinsky-2:3c6374e7a9a17e01afe306a5218cc67de55b19ea536466d6ea2602cfecea40a9", + "Kandinsky-v2.2": "ai-forever/kandinsky-2.2:ad9d7879fbffa2874e1d909d1d37d9bc682889cc65b31f7bb00d2362619f194a", + "Proteus-v0.2": "lucataco/proteus-v0.2:06775cd262843edbde5abab958abdbb65a0a6b58ca301c9fd78fa55c775fc019", + "Playground-v2.0": "playgroundai/playground-v2-1024px-aesthetic:42fe626e41cc811eaf02c94b892774839268ce1994ea778eba97103fe1ef51b8", + "Playground-v2.5": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24", + "Dreamshaper-xl-turbo": "lucataco/dreamshaper-xl-turbo:0a1710e0187b01a255302738ca0158ff02a22f4638679533e111082f9dd1b615", + "SDXL-Deepcache": "lucataco/sdxl-deepcache:eaf678fb34006669e9a3c6dd5971e2279bf20ee0adeced464d7b6d95de16dc93", + "Openjourney-v4": "prompthero/openjourney:ad59ca21177f9e217b9075e7300cf6e14f7e5b4505b87b9689dbd866e9768969", + "LCM-v1.5": "fofr/latent-consistency-model:683d19dc312f7a9f0428b04429a9ccefd28dbf7785fef083ad5cf991b65f406f", + "Realvisxl-v3.0": "fofr/realvisxl-v3:33279060bbbb8858700eb2146350a98d96ef334fcf817f37eb05915e1534aa1c", + + "Realvisxl-v2.0": "lucataco/realvisxl-v2.0:7d6a2f9c4754477b12c14ed2a58f89bb85128edcdd581d24ce58b6926029de08", + "Pixart-Sigma": "cjwbw/pixart-sigma:5a54352c99d9fef467986bc8f3a20205e8712cbd3df1cbae4975d6254c902de1", + "SSD-1b": "lucataco/ssd-1b:b19e3639452c59ce8295b82aba70a231404cb062f2eb580ea894b31e8ce5bbb6", + "Open-Dalle-v1.1": "lucataco/open-dalle-v1.1:1c7d4c8dec39c7306df7794b28419078cb9d18b9213ab1c21fdc46a1deca0144", + "Deepfloyd-IF": "andreasjansson/deepfloyd-if:fb84d659df149f4515c351e394d22222a94144aa1403870c36025c8b28846c8d", + + "Zeroscope-v2-xl": "anotherjesse/zeroscope-v2-xl:9f747673945c62801b13b84701c783929c0ee784e4748ec062204894dda1a351", + # "Damo-Text-to-Video": "cjwbw/damo-text-to-video:1e205ea73084bd17a0a3b43396e49ba0d6bc2e754e9283b2df49fad2dcf95755", + "Animate-Diff": "lucataco/animate-diff:beecf59c4aee8d81bf04f0381033dfa10dc16e845b4ae00d281e2fa377e48a9f", + "OpenSora": "camenduru/open-sora:8099e5722ba3d5f408cd3e696e6df058137056268939337a3fbe3912e86e72ad", + "LaVie": "cjwbw/lavie:0bca850c4928b6c30052541fa002f24cbb4b677259c461dd041d271ba9d3c517", + "VideoCrafter2": "lucataco/video-crafter:7757c5775e962c618053e7df4343052a21075676d6234e8ede5fa67c9e43bce0", + "Stable-Video-Diffusion": "sunfjun/stable-video-diffusion:d68b6e09eedbac7a49e3d8644999d93579c386a083768235cabca88796d70d82", + "FLUX.1-schnell": "black-forest-labs/flux-schnell", + "FLUX.1-pro": "black-forest-labs/flux-pro", + "FLUX.1-dev": "black-forest-labs/flux-dev", + "FLUX-1.1-pro": "black-forest-labs/flux-1.1-pro", + "SD-v3.5-large": "stability-ai/stable-diffusion-3.5-large", + "SD-v3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo", + } + + +class ReplicateModel(): + def __init__(self, model_name, model_type): + self.model_name = model_name + self.model_type = model_type + + def __call__(self, *args, **kwargs): + if self.model_type == "text2image": + assert "prompt" in kwargs, "prompt is required for text2image model" + output = replicate.run( + f"{Replicate_MODEl_NAME_MAP[self.model_name]}", + input={ + "width": 512, + "height": 512, + "prompt": kwargs["prompt"] + }, + ) + if 'Openjourney' in self.model_name: + for item in output: + result_url = item + break + elif isinstance(output, list): + result_url = output[0] + else: + result_url = output + print(self.model_name, result_url) + response = requests.get(result_url) + result = Image.open(io.BytesIO(response.content)) + return result + + elif self.model_type == "text2video": + assert "prompt" in kwargs, "prompt is required for text2image model" + if self.model_name == "Zeroscope-v2-xl": + input = { + "fps": 24, + "width": 512, + "height": 512, + "prompt": kwargs["prompt"], + "guidance_scale": 17.5, + # "negative_prompt": "very blue, dust, noisy, washed out, ugly, distorted, broken", + "num_frames": 48, + } + elif self.model_name == "Damo-Text-to-Video": + input={ + "fps": 8, + "prompt": kwargs["prompt"], + "num_frames": 16, + "num_inference_steps": 50 + } + elif self.model_name == "Animate-Diff": + input={ + "path": "toonyou_beta3.safetensors", + "seed": 255224557, + "steps": 25, + "prompt": kwargs["prompt"], + "n_prompt": "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth", + "motion_module": "mm_sd_v14", + "guidance_scale": 7.5 + } + elif self.model_name == "OpenSora": + input={ + "seed": 1234, + "prompt": kwargs["prompt"], + } + elif self.model_name == "LaVie": + input={ + "width": 512, + "height": 512, + "prompt": kwargs["prompt"], + "quality": 9, + "video_fps": 8, + "interpolation": False, + "sample_method": "ddpm", + "guidance_scale": 7, + "super_resolution": False, + "num_inference_steps": 50 + } + elif self.model_name == "VideoCrafter2": + input={ + "fps": 24, + "seed": 64045, + "steps": 40, + "width": 512, + "height": 512, + "prompt": kwargs["prompt"], + } + elif self.model_name == "Stable-Video-Diffusion": + text2image_name = "SD-v2.1" + output = replicate.run( + f"{Replicate_MODEl_NAME_MAP[text2image_name]}", + input={ + "width": 512, + "height": 512, + "prompt": kwargs["prompt"] + }, + ) + if isinstance(output, list): + image_url = output[0] + else: + image_url = output + print(image_url) + + input={ + "cond_aug": 0.02, + "decoding_t": 14, + "input_image": "{}".format(image_url), + "video_length": "14_frames_with_svd", + "sizing_strategy": "maintain_aspect_ratio", + "motion_bucket_id": 127, + "frames_per_second": 6 + } + + output = replicate.run( + f"{Replicate_MODEl_NAME_MAP[self.model_name]}", + input=input, + ) + if isinstance(output, list): + result_url = output[0] + else: + result_url = output + print(self.model_name) + print(result_url) + # response = requests.get(result_url) + # result = Image.open(io.BytesIO(response.content)) + + # for event in handler.iter_events(with_logs=True): + # if isinstance(event, fal_client.InProgress): + # print('Request in progress') + # print(event.logs) + + # result = handler.get() + # print("result video: ====") + # print(result) + # result_url = result['video']['url'] + # return result_url + return result_url + else: + raise ValueError("model_type must be text2image or image2image") + + +def load_replicate_model(model_name, model_type): + return ReplicateModel(model_name, model_type) + + +if __name__ == "__main__": + model_name = 'replicate_zeroscope-v2-xl_text2video' + model_source, model_name, model_type = model_name.split("_") + pipe = load_replicate_model(model_name, model_type) + prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic" + result = pipe(prompt=prompt) \ No newline at end of file diff --git a/model_bbox/.gradio/certificate.pem b/model_bbox/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/model_bbox/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/model_bbox/MIGC/__init__.py b/model_bbox/MIGC/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_bbox/MIGC/__pycache__/__init__.cpython-310.pyc b/model_bbox/MIGC/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ac4272b00196ab299fa17197f652f508831a306 Binary files /dev/null and b/model_bbox/MIGC/__pycache__/__init__.cpython-310.pyc differ diff --git a/model_bbox/MIGC/__pycache__/inference_single_image.cpython-310.pyc b/model_bbox/MIGC/__pycache__/inference_single_image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b6a5da9a89744ae43de79b8f438f7c6c4b4a2c9 Binary files /dev/null and b/model_bbox/MIGC/__pycache__/inference_single_image.cpython-310.pyc differ diff --git a/model_bbox/MIGC/inference_single_image.py b/model_bbox/MIGC/inference_single_image.py new file mode 100644 index 0000000000000000000000000000000000000000..ac324d66f3fc3fb2b6f4b641eb917b53cb949d74 --- /dev/null +++ b/model_bbox/MIGC/inference_single_image.py @@ -0,0 +1,193 @@ +import os +import sys +import torch + +migc_path = os.path.dirname(os.path.abspath(__file__)) +print(migc_path) +if migc_path not in sys.path: + sys.path.append(migc_path) +import yaml +from diffusers import EulerDiscreteScheduler +from migc.migc_utils import seed_everything +from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore + +def normalize_bbox(bboxes, img_width, img_height): + normalized_bboxes = [] + for box in bboxes: + x_min, y_min, x_max, y_max = box + + x_min = x_min / img_width + y_min = y_min / img_height + x_max = x_max / img_width + y_max = y_max / img_height + + normalized_bboxes.append([x_min, y_min, x_max, y_max]) + + return [normalized_bboxes] + +def create_simple_prompt(input_str): + # 先将输入字符串按分号分割,并去掉空字符串 + objects = [obj for obj in input_str.split(';') if obj.strip()] + + # 创建详细描述字符串 + prompt_description = "masterpiece, best quality, " + ", ".join(objects) + + # 创建最终结构 + prompt_final = [[prompt_description] + objects] + + return prompt_final + + +def inference_single_image(prompt, grounding_instruction, state): + print(prompt) + print(grounding_instruction) + bbox = state['boxes'] + print(bbox) + bbox = normalize_bbox(bbox, 600, 600) + print(bbox) + simple_prompt = create_simple_prompt(grounding_instruction) + print(simple_prompt) + migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt' + migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path) + print(migc_ckpt_path_all) + assert os.path.isfile(migc_ckpt_path_all), "Please download the ckpt of migc and put it in the pretrained_weighrs/ folder!" + + + sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4" + # MIGC is a plug-and-play controller. + # You can go to https://civitai.com/search/models?baseModel=SD%201.4&baseModel=SD%201.5&sortBy=models_v5 find a base model with better generation ability to achieve better creations. + + # Construct MIGC pipeline + pipe = StableDiffusionMIGCPipeline.from_pretrained( + sd1x_path) + pipe.attention_store = AttentionStore() + from migc.migc_utils import load_migc + load_migc(pipe.unet , pipe.attention_store, + migc_ckpt_path_all, attn_processor=MIGCProcessor) + pipe = pipe.to("cuda") + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + + # prompt_final = [['masterpiece, best quality,black colored ball,gray colored cat,white colored bed,\ + # green colored plant,red colored teddy bear,blue colored wall,brown colored vase,orange colored book,\ + # yellow colored hat', 'black colored ball', 'gray colored cat', 'white colored bed', 'green colored plant', \ + # 'red colored teddy bear', 'blue colored wall', 'brown colored vase', 'orange colored book', 'yellow colored hat']] + + # bboxes = [[[0.3125, 0.609375, 0.625, 0.875], [0.5625, 0.171875, 0.984375, 0.6875], \ + # [0.0, 0.265625, 0.984375, 0.984375], [0.0, 0.015625, 0.21875, 0.328125], \ + # [0.171875, 0.109375, 0.546875, 0.515625], [0.234375, 0.0, 1.0, 0.3125], \ + # [0.71875, 0.625, 0.953125, 0.921875], [0.0625, 0.484375, 0.359375, 0.8125], \ + # [0.609375, 0.09375, 0.90625, 0.28125]]] + negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry' + seed = 7351007268695528845 + seed_everything(seed) + print("Start inference: ") + image = pipe(simple_prompt, bbox, num_inference_steps=50, guidance_scale=7.5, + MIGCsteps=25, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0] + return image + + + + +# def MIGC_Pipe(): +# migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt' +# migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path) +# print(migc_ckpt_path_all) +# assert os.path.isfile(migc_ckpt_path_all), "Please download the ckpt of migc and put it in the pretrained_weighrs/ folder!" +# sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4" +# pipe = StableDiffusionMIGCPipeline.from_pretrained( +# sd1x_path) +# pipe.attention_store = AttentionStore() +# from migc.migc_utils import load_migc +# load_migc(pipe.unet , pipe.attention_store, +# migc_ckpt_path_all, attn_processor=MIGCProcessor) +# pipe = pipe.to("cuda") +# pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) +# return pipe + +def MIGC_Pipe(): + migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt' + migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path) + print(f"加载 MIGC 权重文件路径: {migc_ckpt_path_all}") + + assert os.path.isfile(migc_ckpt_path_all), f"请下载 MIGC 的 ckpt 文件并将其放在 'pretrained_weights/' 文件夹中: {migc_ckpt_path_all}" + + sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4" + print(f"加载 StableDiffusion 模型: {sd1x_path}") + + # 加载 StableDiffusionMIGCPipeline + print("load sd:") + pipe = StableDiffusionMIGCPipeline.from_pretrained(sd1x_path) + pipe.attention_store = AttentionStore() + + # 导入并加载 MIGC 权重 + print("load migc") + from migc.migc_utils import load_migc + load_migc(pipe.unet, pipe.attention_store, migc_ckpt_path_all, attn_processor=MIGCProcessor) + + # 确保模型和 attention_store 被正确加载 + assert pipe.unet is not None, "unet 模型未正确加载!" + assert pipe.attention_store is not None, "attention_store 未正确加载!" + + # 转移到 CUDA + if torch.cuda.is_available(): + device = torch.device("cuda") + print("使用 CUDA 设备") + else: + device = torch.device("cpu") + print("使用 CPU") + + pipe = pipe.to(device) + + # 设置调度器 + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + return pipe + + +def create_simple_prompt(input_str): + # 先将输入字符串按分号分割,并去掉空字符串 + objects = [obj for obj in input_str.split(';') if obj.strip()] + + # 创建详细描述字符串 + prompt_description = "masterpiece, best quality, " + ", ".join(objects) + + # 创建最终结构 + prompt_final = [[prompt_description] + objects] + + return prompt_final + + +def inference_image(pipe, prompt, grounding_instruction, state): + print(prompt) + print(grounding_instruction) + bbox = state['boxes'] + print(bbox) + bbox = normalize_bbox(bbox, 600, 600) + print(bbox) + simple_prompt = create_simple_prompt(grounding_instruction) + print(simple_prompt) + negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry' + seed = 7351007268695528845 + seed_everything(seed) + print("Start inference: ") + image = pipe(simple_prompt, bbox, num_inference_steps=50, guidance_scale=7.5, + MIGCsteps=25, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0] + return image + + + +if __name__ == "__main__": + prompt_final = [['masterpiece, best quality,black colored ball,gray colored cat,white colored bed,\ + green colored plant,red colored teddy bear,blue colored wall,brown colored vase,orange colored book,\ + yellow colored hat', 'black colored ball', 'gray colored cat', 'white colored bed', 'green colored plant', \ + 'red colored teddy bear', 'blue colored wall', 'brown colored vase', 'orange colored book', 'yellow colored hat']] + + bboxes = [[[0.3125, 0.609375, 0.625, 0.875], [0.5625, 0.171875, 0.984375, 0.6875], \ + [0.0, 0.265625, 0.984375, 0.984375], [0.0, 0.015625, 0.21875, 0.328125], \ + [0.171875, 0.109375, 0.546875, 0.515625], [0.234375, 0.0, 1.0, 0.3125], \ + [0.71875, 0.625, 0.953125, 0.921875], [0.0625, 0.484375, 0.359375, 0.8125], \ + [0.609375, 0.09375, 0.90625, 0.28125]]] + image = inference_single_image("a cat", prompt_final, bboxes) + image.save("output.png") + print("done") \ No newline at end of file diff --git a/model_bbox/MIGC/migc/__init__.py b/model_bbox/MIGC/migc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_bbox/MIGC/migc/__pycache__/__init__.cpython-310.pyc b/model_bbox/MIGC/migc/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..678c06519878a1c06135d57134f35b28fcca2001 Binary files /dev/null and b/model_bbox/MIGC/migc/__pycache__/__init__.cpython-310.pyc differ diff --git a/model_bbox/MIGC/migc/__pycache__/migc_arch.cpython-310.pyc b/model_bbox/MIGC/migc/__pycache__/migc_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..089c11fb9c4c1cbb926270d6df04558a5243c077 Binary files /dev/null and b/model_bbox/MIGC/migc/__pycache__/migc_arch.cpython-310.pyc differ diff --git a/model_bbox/MIGC/migc/__pycache__/migc_layers.cpython-310.pyc b/model_bbox/MIGC/migc/__pycache__/migc_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156b3ce11535da98dededd31fe137237fe4affcc Binary files /dev/null and b/model_bbox/MIGC/migc/__pycache__/migc_layers.cpython-310.pyc differ diff --git a/model_bbox/MIGC/migc/__pycache__/migc_pipeline.cpython-310.pyc b/model_bbox/MIGC/migc/__pycache__/migc_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0f227461a7ded49a5042ff5997e2889be8a2059 Binary files /dev/null and b/model_bbox/MIGC/migc/__pycache__/migc_pipeline.cpython-310.pyc differ diff --git a/model_bbox/MIGC/migc/__pycache__/migc_utils.cpython-310.pyc b/model_bbox/MIGC/migc/__pycache__/migc_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3eeb056aabfdb8a2fc6ef6268ab65e5c5f8bd76 Binary files /dev/null and b/model_bbox/MIGC/migc/__pycache__/migc_utils.cpython-310.pyc differ diff --git a/model_bbox/MIGC/migc/migc_arch.py b/model_bbox/MIGC/migc/migc_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6f9c813175408fc7f4c11a4aec26b01eb84923 --- /dev/null +++ b/model_bbox/MIGC/migc/migc_arch.py @@ -0,0 +1,220 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from migc.migc_layers import CBAM, CrossAttention, LayoutAttention + + +class FourierEmbedder(): + def __init__(self, num_freqs=64, temperature=100): + self.num_freqs = num_freqs + self.temperature = temperature + self.freq_bands = temperature ** ( torch.arange(num_freqs) / num_freqs ) + + @ torch.no_grad() + def __call__(self, x, cat_dim=-1): + out = [] + for freq in self.freq_bands: + out.append( torch.sin( freq*x ) ) + out.append( torch.cos( freq*x ) ) + return torch.cat(out, cat_dim) # torch.Size([5, 30, 64]) + + +class PositionNet(nn.Module): + def __init__(self, in_dim, out_dim, fourier_freqs=8): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy + + # -------------------------------------------------------------- # + self.linears_position = nn.Sequential( + nn.Linear(self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + def forward(self, boxes): + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = self.fourier_embedder(boxes) # B*1*4 --> B*1*C torch.Size([5, 1, 64]) + xyxy_embedding = self.linears_position(xyxy_embedding) # B*1*C --> B*1*768 torch.Size([5, 1, 768]) + + return xyxy_embedding + + +class SAC(nn.Module): + def __init__(self, C, number_pro=30): + super().__init__() + self.C = C + self.number_pro = number_pro + self.conv1 = nn.Conv2d(C + 1, C, 1, 1) + self.cbam1 = CBAM(C) + self.conv2 = nn.Conv2d(C, 1, 1, 1) + self.cbam2 = CBAM(number_pro, reduction_ratio=1) + + def forward(self, x, guidance_mask, sac_scale=None): + ''' + :param x: (B, phase_num, HW, C) + :param guidance_mask: (B, phase_num, H, W) + :return: + ''' + B, phase_num, HW, C = x.shape + _, _, H, W = guidance_mask.shape + guidance_mask = guidance_mask.view(guidance_mask.shape[0], phase_num, -1)[ + ..., None] # (B, phase_num, HW, 1) + + null_x = torch.zeros_like(x[:, [0], ...]).to(x.device) + null_mask = torch.zeros_like(guidance_mask[:, [0], ...]).to(guidance_mask.device) + + x = torch.cat([x, null_x], dim=1) + guidance_mask = torch.cat([guidance_mask, null_mask], dim=1) + phase_num += 1 + + + scale = torch.cat([x, guidance_mask], dim=-1) # (B, phase_num, HW, C+1) + scale = scale.view(-1, H, W, C + 1) # (B * phase_num, H, W, C+1) + scale = scale.permute(0, 3, 1, 2) # (B * phase_num, C+1, H, W) + scale = self.conv1(scale) # (B * phase_num, C, H, W) + scale = self.cbam1(scale) # (B * phase_num, C, H, W) + scale = self.conv2(scale) # (B * phase_num, 1, H, W) + scale = scale.view(B, phase_num, H, W) # (B, phase_num, H, W) + + null_scale = scale[:, [-1], ...] + scale = scale[:, :-1, ...] + x = x[:, :-1, ...] + + pad_num = self.number_pro - phase_num + 1 + + ori_phase_num = scale[:, 1:-1, ...].shape[1] + phase_scale = torch.cat([scale[:, 1:-1, ...], null_scale.repeat(1, pad_num, 1, 1)], dim=1) + shuffled_order = torch.randperm(phase_scale.shape[1]) + inv_shuffled_order = torch.argsort(shuffled_order) + + random_phase_scale = phase_scale[:, shuffled_order, ...] + + scale = torch.cat([scale[:, [0], ...], random_phase_scale, scale[:, [-1], ...]], dim=1) + # (B, number_pro, H, W) + + scale = self.cbam2(scale) # (B, number_pro, H, W) + scale = scale.view(B, self.number_pro, HW)[..., None] # (B, number_pro, HW) + + random_phase_scale = scale[:, 1: -1, ...] + phase_scale = random_phase_scale[:, inv_shuffled_order[:ori_phase_num], :] + if sac_scale is not None: + instance_num = len(sac_scale) + for i in range(instance_num): + phase_scale[:, i, ...] = phase_scale[:, i, ...] * sac_scale[i] + + + scale = torch.cat([scale[:, [0], ...], phase_scale, scale[:, [-1], ...]], dim=1) + + scale = scale.softmax(dim=1) # (B, phase_num, HW, 1) + out = (x * scale).sum(dim=1, keepdims=True) # (B, 1, HW, C) + return out, scale + + +class MIGC(nn.Module): + def __init__(self, C, attn_type='base', context_dim=768, heads=8): + super().__init__() + self.ea = CrossAttention(query_dim=C, context_dim=context_dim, + heads=heads, dim_head=C // heads, + dropout=0.0) + self.la = LayoutAttention(query_dim=C, + heads=heads, dim_head=C // heads, + dropout=0.0) + self.norm = nn.LayerNorm(C) + self.sac = SAC(C) + self.pos_net = PositionNet(in_dim=768, out_dim=768) + + def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False): + # x: (B, instance_num+1, HW, C) + # guidance_mask: (B, instance_num, H, W) + # box: (instance_num, 4) + # image_token: (B, instance_num+1, HW, C) + full_H = other_info['height'] + full_W = other_info['width'] + B, _, HW, C = ca_x.shape + instance_num = guidance_mask.shape[1] + down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2])) + H = full_H // down_scale + W = full_W // down_scale + guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear') # (B, instance_num, H, W) + + + supplement_mask = other_info['supplement_mask'] # (B, 1, 64, 64) + supplement_mask = F.interpolate(supplement_mask, size=(H, W), mode='bilinear') # (B, 1, H, W) + image_token = other_info['image_token'] + assert image_token.shape == ca_x.shape + context = other_info['context_pooler'] + box = other_info['box'] + box = box.view(B * instance_num, 1, -1) + box_token = self.pos_net(box) + context = torch.cat([context[1:, ...], box_token], dim=1) + ca_scale = other_info['ca_scale'] if 'ca_scale' in other_info else None + ea_scale = other_info['ea_scale'] if 'ea_scale' in other_info else None + sac_scale = other_info['sac_scale'] if 'sac_scale' in other_info else None + + ea_x, ea_attn = self.ea(self.norm(image_token[:, 1:, ...].view(B * instance_num, HW, C)), + context=context, return_attn=True) + ea_x = ea_x.view(B, instance_num, HW, C) + ea_x = ea_x * guidance_mask.view(B, instance_num, HW, 1) + + ca_x[:, 1:, ...] = ca_x[:, 1:, ...] * guidance_mask.view(B, instance_num, HW, 1) # (B, phase_num, HW, C) + if ca_scale is not None: + assert len(ca_scale) == instance_num + for i in range(instance_num): + ca_x[:, i+1, ...] = ca_x[:, i+1, ...] * ca_scale[i] + ea_x[:, i, ...] * ea_scale[i] + else: + ca_x[:, 1:, ...] = ca_x[:, 1:, ...] + ea_x + + ori_image_token = image_token[:, 0, ...] # (B, HW, C) + fusion_template = self.la(x=ori_image_token, guidance_mask=torch.cat([guidance_mask[:, :, ...], supplement_mask], dim=1)) # (B, HW, C) + fusion_template = fusion_template.view(B, 1, HW, C) # (B, 1, HW, C) + + ca_x = torch.cat([ca_x, fusion_template], dim = 1) + ca_x[:, 0, ...] = ca_x[:, 0, ...] * supplement_mask.view(B, HW, 1) + guidance_mask = torch.cat([ + supplement_mask, + guidance_mask, + torch.ones(B, 1, H, W).to(guidance_mask.device) + ], dim=1) + + + out_MIGC, sac_scale = self.sac(ca_x, guidance_mask, sac_scale=sac_scale) + if return_fuser_info: + fuser_info = {} + fuser_info['sac_scale'] = sac_scale.view(B, instance_num + 2, H, W) + fuser_info['ea_attn'] = ea_attn.mean(dim=1).view(B, instance_num, H, W, 2) + return out_MIGC, fuser_info + else: + return out_MIGC + + +class NaiveFuser(nn.Module): + def __init__(self): + super().__init__() + def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False): + # ca_x: (B, instance_num+1, HW, C) + # guidance_mask: (B, instance_num, H, W) + # box: (instance_num, 4) + # image_token: (B, instance_num+1, HW, C) + full_H = other_info['height'] + full_W = other_info['width'] + B, _, HW, C = ca_x.shape + instance_num = guidance_mask.shape[1] + down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2])) + H = full_H // down_scale + W = full_W // down_scale + guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear') # (B, instance_num, H, W) + guidance_mask = torch.cat([torch.ones(B, 1, H, W).to(guidance_mask.device), guidance_mask * 10], dim=1) # (B, instance_num+1, H, W) + guidance_mask = guidance_mask.view(B, instance_num + 1, HW, 1) + out_MIGC = (ca_x * guidance_mask).sum(dim=1) / (guidance_mask.sum(dim=1) + 1e-6) + if return_fuser_info: + return out_MIGC, None + else: + return out_MIGC \ No newline at end of file diff --git a/model_bbox/MIGC/migc/migc_layers.py b/model_bbox/MIGC/migc/migc_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..091564b77f7f704c16dbda79233f2c82182e2c40 --- /dev/null +++ b/model_bbox/MIGC/migc/migc_layers.py @@ -0,0 +1,241 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import random +import math +from inspect import isfunction +from einops import rearrange, repeat +from torch import nn, einsum + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None, return_attn=False, need_softmax=True, guidance_mask=None, + forward_layout_guidance=False): + h = self.heads + b = x.shape[0] + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + if forward_layout_guidance: + # sim: (B * phase_num * h, HW, 77), b = B * phase_num + # guidance_mask: (B, phase_num, 64, 64) + HW = sim.shape[1] + H = W = int(math.sqrt(HW)) + guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='nearest') # (B, phase_num, H, W) + sim = sim.view(b, h, HW, 77) + guidance_mask = guidance_mask.view(b, 1, HW, 1) + guidance_mask[guidance_mask == 1] = 5.0 + guidance_mask[guidance_mask == 0] = 0.1 + sim[:, :, :, 1:] = sim[:, :, :, 1:] * guidance_mask + sim = sim.view(b * h, HW, 77) + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + if need_softmax: + attn = sim.softmax(dim=-1) + else: + attn = sim + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + if return_attn: + attn = attn.view(b, h, attn.shape[-2], attn.shape[-1]) + return self.to_out(out), attn + else: + return self.to_out(out) + + +class LayoutAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., use_lora=False): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.use_lora = use_lora + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None, return_attn=False, need_softmax=True, guidance_mask=None): + h = self.heads + b = x.shape[0] + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + _, phase_num, H, W = guidance_mask.shape + HW = H * W + guidance_mask_o = guidance_mask.view(b * phase_num, HW, 1) + guidance_mask_t = guidance_mask.view(b * phase_num, 1, HW) + guidance_mask_sim = torch.bmm(guidance_mask_o, guidance_mask_t) # (B * phase_num, HW, HW) + guidance_mask_sim = guidance_mask_sim.view(b, phase_num, HW, HW).sum(dim=1) + guidance_mask_sim[guidance_mask_sim > 1] = 1 # (B, HW, HW) + guidance_mask_sim = guidance_mask_sim.view(b, 1, HW, HW) + guidance_mask_sim = guidance_mask_sim.repeat(1, self.heads, 1, 1) + guidance_mask_sim = guidance_mask_sim.view(b * self.heads, HW, HW) # (B * head, HW, HW) + + sim[:, :, :HW][guidance_mask_sim == 0] = -torch.finfo(sim.dtype).max + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + + if need_softmax: + attn = sim.softmax(dim=-1) + else: + attn = sim + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + if return_attn: + attn = attn.view(b, h, attn.shape[-2], attn.shape[-1]) + return self.to_out(out), attn + else: + return self.to_out(out) + + +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=False, bias=False): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None + self.relu = nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type=='avg': + avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( avg_pool ) + elif pool_type=='max': + max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( max_pool ) + elif pool_type=='lp': + lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( lp_pool ) + elif pool_type=='lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp( lse_pool ) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) + return x * scale + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = F.sigmoid(x_out) # broadcasting + return x * scale + +class CBAM(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(CBAM, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial=no_spatial + if not no_spatial: + self.SpatialGate = SpatialGate() + def forward(self, x): + x_out = self.ChannelGate(x) + if not self.no_spatial: + x_out = self.SpatialGate(x_out) + return x_out \ No newline at end of file diff --git a/model_bbox/MIGC/migc/migc_pipeline.py b/model_bbox/MIGC/migc/migc_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..94e41a012e78dff8b31257c7b666fd348fafe3f4 --- /dev/null +++ b/model_bbox/MIGC/migc/migc_pipeline.py @@ -0,0 +1,928 @@ +import glob +import random +import time +from typing import Any, Callable, Dict, List, Optional, Union +# import moxing as mox +import numpy as np +import torch +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import Attention +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionPipeline, + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import logging +from PIL import Image, ImageDraw, ImageFont +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +import inspect +import os +import math +import torch.nn as nn +import torch.nn.functional as F +# from utils import load_utils +import argparse +import yaml +import cv2 +import math +from migc.migc_arch import MIGC, NaiveFuser +from scipy.ndimage import uniform_filter, gaussian_filter + +logger = logging.get_logger(__name__) + +class AttentionStore: + @staticmethod + def get_empty_store(): + return {"down": [], "mid": [], "up": []} + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if is_cross: + if attn.shape[1] in self.attn_res: + self.step_store[place_in_unet].append(attn) + + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers: + self.cur_att_layer = 0 + self.between_steps() + + def between_steps(self): + self.attention_store = self.step_store + self.step_store = self.get_empty_store() + + def maps(self, block_type: str): + return self.attention_store[block_type] + + def reset(self): + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + + def __init__(self, attn_res=[64*64, 32*32, 16*16, 8*8]): + """ + Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion + process + """ + self.num_att_layers = -1 + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + self.curr_step_index = 0 + self.attn_res = attn_res + + +def get_sup_mask(mask_list): + or_mask = np.zeros_like(mask_list[0]) + for mask in mask_list: + or_mask += mask + or_mask[or_mask >= 1] = 1 + sup_mask = 1 - or_mask + return sup_mask + + +class MIGCProcessor(nn.Module): + def __init__(self, config, attnstore, place_in_unet): + super().__init__() + self.attnstore = attnstore + self.place_in_unet = place_in_unet + self.not_use_migc = config['not_use_migc'] + self.naive_fuser = NaiveFuser() + self.embedding = {} + if not self.not_use_migc: + self.migc = MIGC(config['C']) + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + prompt_nums=[], + bboxes=[], + ith=None, + embeds_pooler=None, + timestep=None, + height=512, + width=512, + MIGCsteps=20, + NaiveFuserSteps=-1, + ca_scale=None, + ea_scale=None, + sac_scale=None, + use_sa_preserve=False, + sa_preserve=False, + ): + batch_size, sequence_length, _ = hidden_states.shape + assert(batch_size == 2, "We currently only implement sampling with batch_size=1, \ + and we will implement sampling with batch_size=N as soon as possible.") + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + instance_num = len(bboxes[0]) + + if ith > MIGCsteps: + not_use_migc = True + else: + not_use_migc = self.not_use_migc + is_vanilla_cross = (not_use_migc and ith > NaiveFuserSteps) + if instance_num == 0: + is_vanilla_cross = True + + is_cross = encoder_hidden_states is not None + + ori_hidden_states = hidden_states.clone() + + # Only Need Negative Prompt and Global Prompt. + if is_cross and is_vanilla_cross: + encoder_hidden_states = encoder_hidden_states[:2, ...] + + # In this case, we need to use MIGC or naive_fuser, so we copy the hidden_states_cond (instance_num+1) times for QKV + if is_cross and not is_vanilla_cross: + hidden_states_uncond = hidden_states[[0], ...] + hidden_states_cond = hidden_states[[1], ...].repeat(instance_num + 1, 1, 1) + hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond]) + + # QKV Operation of Vanilla Self-Attention or Cross-Attention + query = attn.to_q(hidden_states) + + if ( + not is_cross + and use_sa_preserve + and timestep.item() in self.embedding + and self.place_in_unet == "up" + ): + hidden_states = torch.cat((hidden_states, torch.from_numpy(self.embedding[timestep.item()]).to(hidden_states.device)), dim=1) + + if not is_cross and sa_preserve and self.place_in_unet == "up": + self.embedding[timestep.item()] = ori_hidden_states.cpu().numpy() + + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) # 48 4096 77 + self.attnstore(attention_probs, is_cross, self.place_in_unet) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + ###### Self-Attention Results ###### + if not is_cross: + return hidden_states + + ###### Vanilla Cross-Attention Results ###### + if is_vanilla_cross: + return hidden_states + + ###### Cross-Attention with MIGC ###### + assert (not is_vanilla_cross) + # hidden_states: torch.Size([1+1+instance_num, HW, C]), the first 1 is the uncond ca output, the second 1 is the global ca output. + hidden_states_uncond = hidden_states[[0], ...] # torch.Size([1, HW, C]) + cond_ca_output = hidden_states[1: , ...].unsqueeze(0) # torch.Size([1, 1+instance_num, 5, 64, 1280]) + guidance_masks = [] + in_box = [] + # Construct Instance Guidance Mask + for bbox in bboxes[0]: + guidance_mask = np.zeros((height, width)) + w_min = int(width * bbox[0]) + w_max = int(width * bbox[2]) + h_min = int(height * bbox[1]) + h_max = int(height * bbox[3]) + guidance_mask[h_min: h_max, w_min: w_max] = 1.0 + guidance_masks.append(guidance_mask[None, ...]) + in_box.append([bbox[0], bbox[2], bbox[1], bbox[3]]) + + # Construct Background Guidance Mask + sup_mask = get_sup_mask(guidance_masks) + supplement_mask = torch.from_numpy(sup_mask[None, ...]) + supplement_mask = F.interpolate(supplement_mask, (height//8, width//8), mode='bilinear').float() + supplement_mask = supplement_mask.to(hidden_states.device) # (1, 1, H, W) + + guidance_masks = np.concatenate(guidance_masks, axis=0) + guidance_masks = guidance_masks[None, ...] + guidance_masks = torch.from_numpy(guidance_masks).float().to(cond_ca_output.device) + guidance_masks = F.interpolate(guidance_masks, (height//8, width//8), mode='bilinear') # (1, instance_num, H, W) + + in_box = torch.from_numpy(np.array(in_box))[None, ...].float().to(cond_ca_output.device) # (1, instance_num, 4) + + other_info = {} + other_info['image_token'] = hidden_states_cond[None, ...] + other_info['context'] = encoder_hidden_states[1:, ...] + other_info['box'] = in_box + other_info['context_pooler'] =embeds_pooler # (instance_num, 1, 768) + other_info['supplement_mask'] = supplement_mask + other_info['attn2'] = None + other_info['attn'] = attn + other_info['height'] = height + other_info['width'] = width + other_info['ca_scale'] = ca_scale + other_info['ea_scale'] = ea_scale + other_info['sac_scale'] = sac_scale + + if not not_use_migc: + hidden_states_cond, fuser_info = self.migc(cond_ca_output, + guidance_masks, + other_info=other_info, + return_fuser_info=True) + else: + hidden_states_cond, fuser_info = self.naive_fuser(cond_ca_output, + guidance_masks, + other_info=other_info, + return_fuser_info=True) + hidden_states_cond = hidden_states_cond.squeeze(1) + + hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond]) + return hidden_states + + +class StableDiffusionMIGCPipeline(StableDiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + # Get the parameter signature of the parent class constructor + parent_init_signature = inspect.signature(super().__init__) + parent_init_params = parent_init_signature.parameters + + # Dynamically build a parameter dictionary based on the parameters of the parent class constructor + init_kwargs = { + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "unet": unet, + "scheduler": scheduler, + "safety_checker": safety_checker, + "feature_extractor": feature_extractor, + "requires_safety_checker": requires_safety_checker + } + if 'image_encoder' in parent_init_params.items(): + init_kwargs['image_encoder'] = image_encoder + super().__init__(**init_kwargs) + + self.instance_set = set() + self.embedding = {} + + def _encode_prompt( + self, + prompts, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompts is not None and isinstance(prompts, str): + batch_size = 1 + elif prompts is not None and isinstance(prompts, list): + batch_size = len(prompts) + else: + batch_size = prompt_embeds.shape[0] + + prompt_embeds_none_flag = (prompt_embeds is None) + prompt_embeds_list = [] + embeds_pooler_list = [] + for prompt in prompts: + if prompt_embeds_none_flag: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + embeds_pooler = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + embeds_pooler = embeds_pooler.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + embeds_pooler = embeds_pooler.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + embeds_pooler = embeds_pooler.view( + bs_embed * num_images_per_prompt, -1 + ) + prompt_embeds_list.append(prompt_embeds) + embeds_pooler_list.append(embeds_pooler) + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + embeds_pooler = torch.cat(embeds_pooler_list, dim=0) + # negative_prompt_embeds: (prompt_nums[0]+prompt_nums[1]+...prompt_nums[n], token_num, token_channel), + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + negative_prompt = "worst quality, low quality, bad anatomy" + uncond_tokens = [negative_prompt] * batch_size + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # negative_prompt_embeds: (len(prompt_nums), token_num, token_channel), + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + final_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return final_prompt_embeds, prompt_embeds, embeds_pooler[:, None, :] + + def check_inputs( + self, + prompt, + token_indices, + bboxes, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if token_indices is not None: + if isinstance(token_indices, list): + if isinstance(token_indices[0], list): + if isinstance(token_indices[0][0], list): + token_indices_batch_size = len(token_indices) + elif isinstance(token_indices[0][0], int): + token_indices_batch_size = 1 + else: + raise TypeError( + "`token_indices` must be a list of lists of integers or a list of integers." + ) + else: + raise TypeError( + "`token_indices` must be a list of lists of integers or a list of integers." + ) + else: + raise TypeError( + "`token_indices` must be a list of lists of integers or a list of integers." + ) + + if bboxes is not None: + if isinstance(bboxes, list): + if isinstance(bboxes[0], list): + if ( + isinstance(bboxes[0][0], list) + and len(bboxes[0][0]) == 4 + and all(isinstance(x, float) for x in bboxes[0][0]) + ): + bboxes_batch_size = len(bboxes) + elif ( + isinstance(bboxes[0], list) + and len(bboxes[0]) == 4 + and all(isinstance(x, float) for x in bboxes[0]) + ): + bboxes_batch_size = 1 + else: + print(isinstance(bboxes[0], list), len(bboxes[0])) + raise TypeError( + "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats." + ) + else: + print(isinstance(bboxes[0], list), len(bboxes[0])) + raise TypeError( + "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats." + ) + else: + print(isinstance(bboxes[0], list), len(bboxes[0])) + raise TypeError( + "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats." + ) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if token_indices_batch_size != prompt_batch_size: + raise ValueError( + f"token indices batch size must be same as prompt batch size. token indices batch size: {token_indices_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + if bboxes_batch_size != prompt_batch_size: + raise ValueError( + f"bbox batch size must be same as prompt batch size. bbox batch size: {bboxes_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def get_indices(self, prompt: str) -> Dict[str, int]: + """Utility function to list the indices of the tokens you wish to alte""" + ids = self.tokenizer(prompt).input_ids + indices = { + i: tok + for tok, i in zip( + self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)) + ) + } + return indices + + @staticmethod + def draw_box(pil_img: Image, bboxes: List[List[float]]) -> Image: + """Utility function to draw bbox on the image""" + width, height = pil_img.size + draw = ImageDraw.Draw(pil_img) + + for obj_box in bboxes: + x_min, y_min, x_max, y_max = ( + obj_box[0] * width, + obj_box[1] * height, + obj_box[2] * width, + obj_box[3] * height, + ) + draw.rectangle( + [int(x_min), int(y_min), int(x_max), int(y_max)], + outline="red", + width=4, + ) + + return pil_img + + + @staticmethod + def draw_box_desc(pil_img: Image, bboxes: List[List[float]], prompt: List[str]) -> Image: + """Utility function to draw bbox on the image""" + color_list = ['red', 'blue', 'yellow', 'purple', 'green', 'black', 'brown', 'orange', 'white', 'gray'] + width, height = pil_img.size + draw = ImageDraw.Draw(pil_img) + font_folder = os.path.dirname(os.path.dirname(__file__)) + font_path = os.path.join(font_folder, 'Rainbow-Party-2.ttf') + font = ImageFont.truetype(font_path, 30) + + for box_id in range(len(bboxes)): + obj_box = bboxes[box_id] + text = prompt[box_id] + fill = 'black' + for color in prompt[box_id].split(' '): + if color in color_list: + fill = color + text = text.split(',')[0] + x_min, y_min, x_max, y_max = ( + obj_box[0] * width, + obj_box[1] * height, + obj_box[2] * width, + obj_box[3] * height, + ) + draw.rectangle( + [int(x_min), int(y_min), int(x_max), int(y_max)], + outline=fill, + width=4, + ) + draw.text((int(x_min), int(y_min)), text, fill=fill, font=font) + + return pil_img + + + @torch.no_grad() + def __call__( + self, + prompt: List[List[str]] = None, + bboxes: List[List[List[float]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + MIGCsteps=20, + NaiveFuserSteps=-1, + ca_scale=None, + ea_scale=None, + sac_scale=None, + aug_phase_with_and=False, + sa_preserve=False, + use_sa_preserve=False, + clear_set=False, + GUI_progress=None + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + token_indices (Union[List[List[List[int]]], List[List[int]]], optional): + The list of the indexes in the prompt to layout. Defaults to None. + bboxes (Union[List[List[List[float]]], List[List[float]]], optional): + The bounding boxes of the indexes to maintain layout in the image. Defaults to None. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + max_guidance_iter (`int`, *optional*, defaults to `10`): + The maximum number of iterations for the layout guidance on attention maps in diffusion mode. + max_guidance_iter_per_step (`int`, *optional*, defaults to `5`): + The maximum number of iterations to run during each time step for layout guidance. + scale_factor (`int`, *optional*, defaults to `50`): + The scale factor used to update the latents during optimization. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + def aug_phase_with_and_function(phase, instance_num): + instance_num = min(instance_num, 7) + copy_phase = [phase] * instance_num + phase = ', and '.join(copy_phase) + return phase + + if aug_phase_with_and: + instance_num = len(prompt[0]) - 1 + for i in range(1, len(prompt[0])): + prompt[0][i] = aug_phase_with_and_function(prompt[0][i], + instance_num) + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_nums = [0] * len(prompt) + for i, _ in enumerate(prompt): + prompt_nums[i] = len(_) + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, cond_prompt_embeds, embeds_pooler = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + # print(prompt_embeds.shape) 3 77 768 + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + if clear_set: + self.instance_set = set() + self.embedding = {} + + now_set = set() + for i in range(len(bboxes[0])): + now_set.add((tuple(bboxes[0][i]), prompt[0][i + 1])) + + mask_set = (now_set | self.instance_set) - (now_set & self.instance_set) + self.instance_set = now_set + + guidance_mask = np.full((4, height // 8, width // 8), 1.0) + + for bbox, _ in mask_set: + w_min = max(0, int(width * bbox[0] // 8) - 5) + w_max = min(width, int(width * bbox[2] // 8) + 5) + h_min = max(0, int(height * bbox[1] // 8) - 5) + h_max = min(height, int(height * bbox[3] // 8) + 5) + guidance_mask[:, h_min:h_max, w_min:w_max] = 0 + + kernal_size = 5 + guidance_mask = uniform_filter( + guidance_mask, axes = (1, 2), size = kernal_size + ) + + guidance_mask = torch.from_numpy(guidance_mask).to(self.device).unsqueeze(0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if GUI_progress is not None: + GUI_progress[0] = int((i + 1) / len(timesteps) * 100) + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + # predict the noise residual + cross_attention_kwargs = {'prompt_nums': prompt_nums, + 'bboxes': bboxes, + 'ith': i, + 'embeds_pooler': embeds_pooler, + 'timestep': t, + 'height': height, + 'width': width, + 'MIGCsteps': MIGCsteps, + 'NaiveFuserSteps': NaiveFuserSteps, + 'ca_scale': ca_scale, + 'ea_scale': ea_scale, + 'sac_scale': sac_scale, + 'sa_preserve': sa_preserve, + 'use_sa_preserve': use_sa_preserve} + + self.unet.eval() + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + step_output = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ) + latents = step_output.prev_sample + + ori_input = latents.detach().clone() + if use_sa_preserve and i in self.embedding: + latents = ( + latents * (1.0 - guidance_mask) + + torch.from_numpy(self.embedding[i]).to(latents.device) * guidance_mask + ).float() + + if sa_preserve: + self.embedding[i] = ori_input.cpu().numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=None + ) \ No newline at end of file diff --git a/model_bbox/MIGC/migc/migc_utils.py b/model_bbox/MIGC/migc/migc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4129ec22c36f492b41aad15f3fed3ba480e4922e --- /dev/null +++ b/model_bbox/MIGC/migc/migc_utils.py @@ -0,0 +1,143 @@ +import argparse +import numpy as np +import torch +import os +import yaml +import random +from diffusers.utils.import_utils import is_accelerate_available +from transformers import CLIPTextModel, CLIPTokenizer +from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore +from diffusers import EulerDiscreteScheduler +if is_accelerate_available(): + from accelerate import init_empty_weights +from contextlib import nullcontext + + +def seed_everything(seed): + # np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + + +import torch +from typing import Callable, Dict, List, Optional, Union +from collections import defaultdict + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" + +# We need to set Attention Processors for the following keys. +all_processor_keys = [ + 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', + 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', + 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', + 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', + 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', + 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', + 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', + 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', + 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', + 'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', + 'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', + 'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', + 'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', + 'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', + 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', + 'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor' +] + +def load_migc(unet, attention_store, pretrained_MIGC_path: Union[str, Dict[str, torch.Tensor]], attn_processor, + **kwargs): + + state_dict = torch.load(pretrained_MIGC_path, map_location="cpu") + + # fill attn processors + attn_processors = {} + state_dict = state_dict['state_dict'] + + + adapter_grouped_dict = defaultdict(dict) + + # change the key of MIGC.ckpt as the form of diffusers unet + for key, value in state_dict.items(): + key_list = key.split(".") + assert 'migc' in key_list + if 'input_blocks' in key_list: + model_type = 'down_blocks' + elif 'middle_block' in key_list: + model_type = 'mid_block' + else: + model_type = 'up_blocks' + index_number = int(key_list[3]) + if model_type == 'down_blocks': + input_num1 = str(index_number//3) + input_num2 = str((index_number%3)-1) + elif model_type == 'mid_block': + input_num1 = '0' + input_num2 = '0' + else: + input_num1 = str(index_number//3) + input_num2 = str(index_number%3) + attn_key_list = [model_type,input_num1,'attentions',input_num2,'transformer_blocks','0'] + if model_type == 'mid_block': + attn_key_list = [model_type,'attentions',input_num2,'transformer_blocks','0'] + attn_processor_key = '.'.join(attn_key_list) + sub_key = '.'.join(key_list[key_list.index('migc'):]) + adapter_grouped_dict[attn_processor_key][sub_key] = value + + # Create MIGC Processor + config = {'not_use_migc': False} + for key, value_dict in adapter_grouped_dict.items(): + dim = value_dict['migc.norm.bias'].shape[0] + config['C'] = dim + key_final = key + '.attn2.processor' + if key_final.startswith("mid_block"): + place_in_unet = "mid" + elif key_final.startswith("up_blocks"): + place_in_unet = "up" + elif key_final.startswith("down_blocks"): + place_in_unet = "down" + + attn_processors[key_final] = attn_processor(config, attention_store, place_in_unet) + attn_processors[key_final].load_state_dict(value_dict) + attn_processors[key_final].to(device=unet.device, dtype=unet.dtype) + + # Create CrossAttention/SelfAttention Processor + config = {'not_use_migc': True} + for key in all_processor_keys: + if key not in attn_processors.keys(): + if key.startswith("mid_block"): + place_in_unet = "mid" + elif key.startswith("up_blocks"): + place_in_unet = "up" + elif key.startswith("down_blocks"): + place_in_unet = "down" + attn_processors[key] = attn_processor(config, attention_store, place_in_unet) + unet.set_attn_processor(attn_processors) + attention_store.num_att_layers = 32 + + +def offlinePipelineSetupWithSafeTensor(sd_safetensors_path): + project_dir = os.path.dirname(os.path.dirname(__file__)) + migc_ckpt_path = os.path.join(project_dir, 'pretrained_weights/MIGC_SD14.ckpt') + clip_model_path = os.path.join(project_dir, 'migc_gui_weights/clip/text_encoder') + clip_tokenizer_path = os.path.join(project_dir, 'migc_gui_weights/clip/tokenizer') + original_config_file = os.path.join(project_dir, 'migc_gui_weights/v1-inference.yaml') + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + # text_encoder = CLIPTextModel(config) + text_encoder = CLIPTextModel.from_pretrained(clip_model_path) + tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path) + pipe = StableDiffusionMIGCPipeline.from_single_file(sd_safetensors_path, + original_config_file=original_config_file, + text_encoder=text_encoder, + tokenizer=tokenizer, + load_safety_checker=False) + print('Initializing pipeline') + pipe.attention_store = AttentionStore() + from migc.migc_utils import load_migc + load_migc(pipe.unet , pipe.attention_store, + migc_ckpt_path, attn_processor=MIGCProcessor) + + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + return pipe \ No newline at end of file diff --git a/model_bbox/MIGC/pretrained_weights/MIGC_SD14.ckpt b/model_bbox/MIGC/pretrained_weights/MIGC_SD14.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..adc893978f7cfe1c151eec659d61a6dd22e39f0e --- /dev/null +++ b/model_bbox/MIGC/pretrained_weights/MIGC_SD14.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81756dd19f7c75f9bba1ead1e6f8fcdfb00030cabb01dc46edd85d950236884c +size 229514282 diff --git a/model_bbox/MIGC/pretrained_weights/PUT_MIGC_CKPT_HERE b/model_bbox/MIGC/pretrained_weights/PUT_MIGC_CKPT_HERE new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_bbox/ReCo/__init__.py b/model_bbox/ReCo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_bbox/ReCo/__pycache__/__init__.cpython-310.pyc b/model_bbox/ReCo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12382119e258e497bff9205167000368e94cb829 Binary files /dev/null and b/model_bbox/ReCo/__pycache__/__init__.cpython-310.pyc differ diff --git a/model_bbox/ReCo/__pycache__/inference.cpython-310.pyc b/model_bbox/ReCo/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43167919f27bd93c9f704326a5678b7ab61d97cf Binary files /dev/null and b/model_bbox/ReCo/__pycache__/inference.cpython-310.pyc differ diff --git a/model_bbox/ReCo/inference.py b/model_bbox/ReCo/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e985bdb8413d6db237575958f541378d9b14b4 --- /dev/null +++ b/model_bbox/ReCo/inference.py @@ -0,0 +1,112 @@ +import os +os.environ ['HF_ENDPOINT'] = 'https://hf-mirror.com' +from ast import main +from numpy import imag +import torch +from diffusers import StableDiffusionPipeline +import os +from PIL import Image + +def normalize_bbox(bboxes, img_width, img_height): + normalized_bboxes = [] + for box in bboxes: + x_min, y_min, x_max, y_max = box + + x_min = (x_min / img_width) + y_min = (y_min / img_height) + x_max = (x_max / img_width) + y_max = (y_max / img_height) + + normalized_bboxes.append([x_min, y_min, x_max, y_max]) + + return normalized_bboxes + +def create_reco_prompt( + caption: str = '', + phrases=[], + boxes=[], + normalize_boxes=True, + image_resolution=512, + num_bins=1000, + ): + """ + method to create ReCo prompt + + caption: global caption + phrases: list of regional captions + boxes: list of regional coordinates (unnormalized xyxy) + """ + + SOS_token = '<|startoftext|>' + EOS_token = '<|endoftext|>' + + box_captions_with_coords = [] + + box_captions_with_coords += [caption] + box_captions_with_coords += [EOS_token] + + for phrase, box in zip(phrases, boxes): + + if normalize_boxes: + box = [float(x) / image_resolution for x in box] + + # quantize into bins + quant_x0 = int(round((box[0] * (num_bins - 1)))) + quant_y0 = int(round((box[1] * (num_bins - 1)))) + quant_x1 = int(round((box[2] * (num_bins - 1)))) + quant_y1 = int(round((box[3] * (num_bins - 1)))) + + # ReCo format + # Add SOS/EOS before/after regional captions + box_captions_with_coords += [ + f"", + f"", + f"", + f"", + SOS_token, + phrase, + EOS_token + ] + + text = " ".join(box_captions_with_coords) + return text + +def inference_image(pipe, prompt, grounding_instruction, state): + print(prompt) + print(grounding_instruction) + bbox = state['boxes'] + # bbox = state + print(bbox) + bbox = normalize_bbox(bbox, 600, 600) + print(bbox) + objects = [obj for obj in grounding_instruction.split(';') if obj.strip()] + print(objects) + prompt_reco = create_reco_prompt(prompt, objects, bbox, normalize_boxes=False) + print(prompt_reco) + image = pipe(prompt_reco, guidance_scale=4).images[0] + return image + + + +if __name__ == "__main__": + path = '/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/snapshots/11a062da5a0a84501047cb19e113f520eb610415' if os.path.isdir('/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/snapshots/11a062da5a0a84501047cb19e113f520eb610415') else "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionPipeline.from_pretrained( + "j-min/reco_sd14_coco", + torch_dtype=torch.float16 + ) + pipe = pipe.to("cuda") + # caption = "A box contains six donuts with varying types of glazes and toppings." + # phrases = ["chocolate donut.", "dark vanilla donut.", "donut with sprinkles.", "donut with powdered sugar.", "pink donut.", "brown donut."] + # boxes = [[263.68, 294.912, 380.544, 392.832], [121.344, 265.216, 267.392, 401.92], [391.168, 294.912, 506.368, 381.952], [120.064, 143.872, 268.8, 270.336], [264.192, 132.928, 393.216, 263.68], [386.048, 148.48, 490.688, 259.584]] + # prompt = create_reco_prompt(caption, phrases, boxes) + # print(prompt) + # generated_image = pipe( + # prompt, + # guidance_scale=4).images[0] + # generated_image.save("output1.jpg") + prompt = "a dog and a cat;" + grounding_instruction = "cut dog; big cat;" + bbox = [(136, 252, 280, 455), (284, 205, 480, 500)] + + inference_image(pipe, prompt, grounding_instruction, bbox) + diff --git a/model_bbox/__init__.py b/model_bbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_bbox/__pycache__/__init__.cpython-310.pyc b/model_bbox/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90702141dca9bd85fa85ee649b28e78137530e64 Binary files /dev/null and b/model_bbox/__pycache__/__init__.cpython-310.pyc differ diff --git a/model_bbox/ksort-logs/vote_log/gr_web_image_editing.log b/model_bbox/ksort-logs/vote_log/gr_web_image_editing.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_bbox/ksort-logs/vote_log/gr_web_image_editing_multi.log b/model_bbox/ksort-logs/vote_log/gr_web_image_editing_multi.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_bbox/ksort-logs/vote_log/gr_web_image_generation.log b/model_bbox/ksort-logs/vote_log/gr_web_image_generation.log new file mode 100644 index 0000000000000000000000000000000000000000..3cc0e3bedfd67d2595bd0a08bbf4f562bdd92652 --- /dev/null +++ b/model_bbox/ksort-logs/vote_log/gr_web_image_generation.log @@ -0,0 +1,267 @@ +2025-01-01 10:21:52 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. +2025-01-01 10:21:52 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) +2025-01-01 10:21:53 | INFO | stdout | /home/bcy/projects/Arena/Control_Ability_Arena/model/models +2025-01-01 10:21:53 | INFO | stdout | /home/bcy/projects/Arena/Control_Ability_Arena/model_bbox/MIGC +2025-01-01 10:21:54 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(, >), received 11. +2025-01-01 10:21:54 | ERROR | stderr | warnings.warn( +2025-01-01 10:21:54 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(, >), received 11. +2025-01-01 10:21:54 | ERROR | stderr | warnings.warn( +2025-01-01 10:21:54 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860 +2025-01-01 10:22:07 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:07 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:07 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:08 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:08 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:08 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:12 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:12 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:12 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:15 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:15 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:15 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:15 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:15 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:15 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:16 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:16 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:16 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:17 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:17 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:17 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:25 | INFO | stdout | +2025-01-01 10:22:25 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3. +2025-01-01 10:22:25 | INFO | stdout | +2025-01-01 10:22:25 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: +2025-01-01 10:22:25 | INFO | stdout | +2025-01-01 10:22:25 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2025-01-01 10:22:25 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +2025-01-01 10:22:25 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio +2025-01-01 10:22:41 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:41 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:41 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:42 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:42 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:42 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:42 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:42 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:42 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:43 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:43 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:43 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:43 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:43 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:43 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:43 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:43 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:43 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:44 | INFO | stdout | background.shape (600, 600, 4) +2025-01-01 10:22:44 | INFO | stdout | len(layers) 1 +2025-01-01 10:22:44 | INFO | stdout | composite.shape (600, 600, 4) +2025-01-01 10:22:47 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None. +2025-01-01 10:22:47 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.") +2025-01-01 10:22:47 | INFO | stdout | [0, 1] +2025-01-01 10:22:47 | INFO | stdout | ['local_MIGC_b2i', 'huggingface_ReCo_b2i'] +2025-01-01 10:22:48 | INFO | stdout | 加载 MIGC 权重文件路径: /home/bcy/projects/Arena/Control_Ability_Arena/model_bbox/MIGC/pretrained_weights/MIGC_SD14.ckpt +2025-01-01 10:22:48 | INFO | stdout | 加载 StableDiffusion 模型: /share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b +2025-01-01 10:22:48 | INFO | stdout | load sd: +2025-01-01 10:22:48 | ERROR | stderr | Loading pipeline components...: 0%| | 0/7 [00:00 <|startoftext|> naked girl <|endoftext|> +2025-01-01 10:22:55 | INFO | stdout | naked girl +2025-01-01 10:22:55 | INFO | stdout | naked girl; +2025-01-01 10:22:55 | INFO | stdout | [(0, 6, 585, 583)] +2025-01-01 10:22:55 | INFO | stdout | [[[0.0, 0.01, 0.975, 0.9716666666666667]]] +2025-01-01 10:22:55 | INFO | stdout | [['masterpiece, best quality, naked girl', 'naked girl']] +2025-01-01 10:22:55 | INFO | stdout | Start inference: +2025-01-01 10:22:56 | ERROR | stderr | 0%| | 0/50 [00:00 <|startoftext|> dog <|endoftext|> +2025-01-01 10:23:43 | INFO | stdout | dog +2025-01-01 10:23:43 | INFO | stdout | dog +2025-01-01 10:23:43 | INFO | stdout | [(165, 202, 407, 486)] +2025-01-01 10:23:43 | INFO | stdout | [[[0.275, 0.33666666666666667, 0.6783333333333333, 0.81]]] +2025-01-01 10:23:43 | INFO | stdout | [['masterpiece, best quality, dog', 'dog']] +2025-01-01 10:23:43 | INFO | stdout | Start inference: +2025-01-01 10:23:43 | ERROR | stderr | 0%| | 0/50 [00:00 1: + diff_mask = mask - last_mask + else: + diff_mask = np.zeros([]) + + # 根据 mask 的变化来计算 box 的位置 + if diff_mask.sum() > 0: + x1x2 = np.where(diff_mask.max(0) != 0)[0] + y1y2 = np.where(diff_mask.max(1) != 0)[0] + y1, y2 = y1y2.min(), y1y2.max() + x1, x2 = x1x2.min(), x1x2.max() + + if (x2 - x1 > 5) and (y2 - y1 > 5): + state['masks'].append(mask.copy()) + state['boxes'].append((x1, y1, x2, y2)) + + # 处理 grounding_texts + grounding_texts = [x.strip() for x in grounding_texts.split(';')] + grounding_texts = [x for x in grounding_texts if len(x) > 0] + if len(grounding_texts) < len(state['boxes']): + grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))] + + # 绘制标注框 + box_image = draw_box(state['boxes'], grounding_texts, background) + + if box_image is not None and state.get('inpaint_hw', None): + inpaint_hw = state['inpaint_hw'] + box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw))) + original_image = state['original_image'].copy() + box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw) + + return [box_image, new_image_trigger, 1.0, state] + +def build_side_by_side_bbox_ui_anony(models): + notice_markdown = """ + # ⚔️ Control-Ability-Arena (Bbox-to-Image Generation) ⚔️ + ## 📜 Rules + - Input a prompt for four anonymized models and vote on their outputs. + - Two voting modes available: Rank Mode and Best Mode. Switch freely between modes. Please note that ties are always allowed. In ranking mode, users can input rankings like 1 3 3 1. Any invalid rankings, such as 1 4 4 1, will be automatically corrected during post-processing. + - Users are encouraged to make evaluations based on subjective preferences. Evaluation criteria: Alignment (50%) + Aesthetics (50%). + - Alignment includes: Entity Matching (30%) + Style Matching (20%); + - Aesthetics includes: Photorealism (30%) + Light and Shadow (10%) + Absence of Artifacts (10%). + + ## 👇 Generating now! + - Note: Due to the API's image safety checks, errors may occur. If this happens, please re-enter a different prompt. + - At times, high API concurrency can cause congestion, potentially resulting in a generation time of up to 1.5 minutes per image. Thank you for your patience. + """ + model_list = models.model_b2i_list + + state = gr.State({}) + state0 = gr.State() + state1 = gr.State() + state2 = gr.State() + state3 = gr.State() + + gen_func = partial(generate_b2i_annoy, models.generate_image_b2i_parallel_anony) + # gen_cache_func = partial(generate_igm_cache_annoy, models.generate_image_ig_cache_anony) + + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + + with gr.Row(): + sketch_pad_trigger = gr.Number(value=0, visible=False) + sketch_pad_resize_trigger = gr.Number(value=0, visible=False) + image_scale = gr.Number(value=0, elem_id="image_scale", visible=False) + + with gr.Row(): + sketch_pad = gr.ImageEditor( + label="Sketch Pad", + type="numpy", + crop_size="1:1", + width=512, + height=512 + ) + out_imagebox = gr.Image( + type="pil", + label="Parsed Sketch Pad", + width=512, + height=512 + ) + + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + container=True, + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue") + + with gr.Row(): + grounding_instruction = gr.Textbox( + label="Grounding instruction (Separated by semicolon)", + placeholder="👉 Enter your Grounding instruction (e.g. a cat; a dog; a bird; a fish)", + ) + + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion("🔍 Expand to see all Arena players", open=False): + # model_description_md = get_model_description_md(model_list) + gr.Markdown("", elem_id="model_description_markdown") + + + with gr.Row(): + with gr.Column(): + chatbot_left = gr.Image(width=512, label = "Model A") + with gr.Column(): + chatbot_left1 = gr.Image(width=512, label = "Model B") + with gr.Column(): + chatbot_right = gr.Image(width=512, label = "Model C") + with gr.Column(): + chatbot_right1 = gr.Image(width=512, label = "Model D") + + with gr.Row(): + with gr.Column(): + model_selector_left = gr.Markdown("", visible=False) + with gr.Column(): + model_selector_left1 = gr.Markdown("", visible=False) + with gr.Column(): + model_selector_right = gr.Markdown("", visible=False) + with gr.Column(): + model_selector_right1 = gr.Markdown("", visible=False) + with gr.Row(): + slow_warning = gr.Markdown("", elem_id="notice_markdown") + + with gr.Row(elem_classes="row"): + with gr.Column(scale=1, min_width=10): + leftvote_btn = gr.Button( + value="A is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button" + ) + with gr.Column(scale=1, min_width=10): + left1vote_btn = gr.Button( + value="B is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button" + ) + with gr.Column(scale=1, min_width=10): + rightvote_btn = gr.Button( + value="C is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button" + ) + with gr.Column(scale=1, min_width=10): + right1vote_btn = gr.Button( + value="D is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button" + ) + with gr.Column(scale=1, min_width=10): + tie_btn = gr.Button( + value="🤝 Tie", visible=False, interactive=False, elem_id="btncolor2", elem_classes="best-button" + ) + + + with gr.Row(): + with gr.Blocks(): + with gr.Row(): + with gr.Column(scale=1, min_width=10): + A1_btn = gr.Button( + value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + A2_btn = gr.Button( + value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + A3_btn = gr.Button( + value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + A4_btn = gr.Button( + value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button" + ) + with gr.Blocks(): + with gr.Row(): + with gr.Column(scale=1, min_width=10): + B1_btn = gr.Button( + value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + B2_btn = gr.Button( + value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + B3_btn = gr.Button( + value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + B4_btn = gr.Button( + value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button" + ) + with gr.Blocks(): + with gr.Row(): + with gr.Column(scale=1, min_width=10): + C1_btn = gr.Button( + value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + C2_btn = gr.Button( + value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + C3_btn = gr.Button( + value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + C4_btn = gr.Button( + value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button" + ) + with gr.Blocks(): + with gr.Row(): + with gr.Column(scale=1, min_width=10): + D1_btn = gr.Button( + value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + D2_btn = gr.Button( + value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + D3_btn = gr.Button( + value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button" + ) + with gr.Column(scale=1, min_width=10): + D4_btn = gr.Button( + value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button" + ) + with gr.Row(): + vote_textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your rank (you can use buttons above, or directly type here, e.g. 1 2 3 4)", + container=True, + elem_id="input_box", + visible=False, + ) + vote_submit_btn = gr.Button(value="Submit", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button") + vote_mode_btn = gr.Button(value="🔄 Mode", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button") + + with gr.Row(): + clear_btn = gr.Button(value="🎲 New Round", interactive=False) + # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + # share_btn = gr.Button(value="📷 Share") + with gr.Blocks(): + with gr.Row(elem_id="centered-text"): # + user_info = gr.Markdown("User information (to appear on the contributor leaderboard)", visible=True, elem_id="centered-text") #, elem_id="centered-text" + # with gr.Blocks(): + # name = gr.Markdown("Name", visible=True) + user_name = gr.Textbox(show_label=False,placeholder="👉 Enter your name (optional)", elem_classes="custom-width") + # with gr.Blocks(): + # institution = gr.Markdown("Institution", visible=True) + user_institution = gr.Textbox(show_label=False,placeholder="👉 Enter your affiliation (optional)", elem_classes="custom-width") + + sketch_pad.change( + draw, + inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state], + outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state], + queue=False, + ) + grounding_instruction.change( + draw, + inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state], + outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state], + queue=False, + ) + + order_btn_list = [textbox, send_btn, clear_btn, grounding_instruction, sketch_pad, out_imagebox] + vote_order_list = [leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \ + A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \ + vote_textbox, vote_submit_btn, vote_mode_btn] + + generate_ig0 = gr.Image(width=512, label = "generate A", visible=False, interactive=False) + generate_ig1 = gr.Image(width=512, label = "generate B", visible=False, interactive=False) + generate_ig2 = gr.Image(width=512, label = "generate C", visible=False, interactive=False) + generate_ig3 = gr.Image(width=512, label = "generate D", visible=False, interactive=False) + dummy_left_model = gr.State("") + dummy_left1_model = gr.State("") + dummy_right_model = gr.State("") + dummy_right1_model = gr.State("") + + ig_rank = [None, None, None, None] + bastA_rank = [0, 3, 3, 3] + bastB_rank = [3, 0, 3, 3] + bastC_rank = [3, 3, 0, 3] + bastD_rank = [3, 3, 3, 0] + tie_rank = [0, 0, 0, 0] + bad_rank = [3, 3, 3, 3] + rank = gr.State(ig_rank) + rankA = gr.State(bastA_rank) + rankB = gr.State(bastB_rank) + rankC = gr.State(bastC_rank) + rankD = gr.State(bastD_rank) + rankTie = gr.State(tie_rank) + rankBad = gr.State(bad_rank) + Top1_text = gr.Textbox(value="Top 1", visible=False, interactive=False) + Top2_text = gr.Textbox(value="Top 2", visible=False, interactive=False) + Top3_text = gr.Textbox(value="Top 3", visible=False, interactive=False) + Top4_text = gr.Textbox(value="Top 4", visible=False, interactive=False) + window1_text = gr.Textbox(value="Model A", visible=False, interactive=False) + window2_text = gr.Textbox(value="Model B", visible=False, interactive=False) + window3_text = gr.Textbox(value="Model C", visible=False, interactive=False) + window4_text = gr.Textbox(value="Model D", visible=False, interactive=False) + vote_level = gr.Number(value=0, visible=False, interactive=False) + # Top1_btn.click(reset_level, inputs=[Top1_text], outputs=[vote_level]) + # Top2_btn.click(reset_level, inputs=[Top2_text], outputs=[vote_level]) + # Top3_btn.click(reset_level, inputs=[Top3_text], outputs=[vote_level]) + # Top4_btn.click(reset_level, inputs=[Top4_text], outputs=[vote_level]) + vote_mode = gr.Textbox(value="Rank", visible=False, interactive=False) + right_vote_text = gr.Textbox(value="wrong", visible=False, interactive=False) + cache_mode = gr.Textbox(value="True", visible=False, interactive=False) + + + textbox.submit( + disable_order_buttons, + inputs=[textbox], + outputs=order_btn_list + ).then( + gen_func, + inputs=[state0, state1, state2, state3, textbox, grounding_instruction, state, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1], + outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \ + model_selector_left, model_selector_left1, model_selector_right, model_selector_right1], + api_name="submit_btn_annony" + ).then( + enable_vote_mode_buttons, + inputs=[vote_mode, textbox], + outputs=vote_order_list + ) + + +if __name__ == "__main__": + with gr.Blocks() as demo: + build_side_by_side_bbox_ui_anony() + demo.launch() \ No newline at end of file diff --git a/serve/leaderboard.py b/serve/leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf099be5d96bb26a648731ddc17f8d5e3dbec45 --- /dev/null +++ b/serve/leaderboard.py @@ -0,0 +1,200 @@ +""" +Live monitor of the website statistics and leaderboard. + +Dependency: +sudo apt install pkg-config libicu-dev +pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate +""" + +import argparse +import ast +import pickle +import os +import threading +import time + +import gradio as gr +import numpy as np +import pandas as pd +import json +from datetime import datetime + + +# def make_leaderboard_md(elo_results): +# leaderboard_md = f""" +# # 🏆 Chatbot Arena Leaderboard +# | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +# This leaderboard is based on the following three benchmarks. +# - [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings. +# - [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses. +# - [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks. + +# 💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023. +# """ +# return leaderboard_md + +def make_leaderboard_md(): + leaderboard_md = f""" +# 🏆 K-Sort Arena Leaderboard (Text-to-Image Generation) +""" + return leaderboard_md + + +def make_leaderboard_video_md(): + leaderboard_md = f""" +# 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation) +""" + return leaderboard_md + + +def model_hyperlink(model_name, link): + return f'{model_name}' + + +def make_arena_leaderboard_md(total_models, total_votes, last_updated): + # last_updated = datetime.now() + # last_updated = last_updated.strftime("%Y-%m-%d") + + leaderboard_md = f""" +Total models: **{total_models}** (anonymized), Total votes: **{total_votes}** (equivalent to **{total_votes*6}** pairwise comparisons) +\n Last updated: {last_updated} +""" + + return leaderboard_md + + +def make_disclaimer_md(): + disclaimer_md = ''' + + +

This platform is designed for academic usage, for details please refer to disclaimer.

+ ''' + return disclaimer_md + + +def make_arena_leaderboard_data(results): + import pandas as pd + df = pd.DataFrame(results) + return df + + +def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'): + with open(score_result_file, "r") as json_file: + data = json.load(json_file) + score_results = data["sorted_score_list"] + total_models = data["total_models"] + total_votes = data["total_votes"] + last_updated = data["last_updated"] + + md = make_leaderboard_md() + md_1 = gr.Markdown(md, elem_id="leaderboard_markdown") + + # with gr.Tab("Arena Score", id=0): + md = make_arena_leaderboard_md(total_models, total_votes, last_updated) + gr.Markdown(md, elem_id="leaderboard_markdown") + md = make_arena_leaderboard_data(score_results) + gr.Dataframe(md) + + gr.Markdown( + """ + - Note: When σ is large (we use the '*' labeling), it indicates that the model did not receive enough votes and its ranking is in the process of being updated. + """, + elem_id="sigma_note_markdown", + ) + + gr.Markdown( + """ ### The leaderboard is regularly updated and continuously incorporates new models. + """, + elem_id="leaderboard_markdown", + ) + with gr.Blocks(): + gr.HTML(make_disclaimer_md) + from .utils import acknowledgment_md, html_code + with gr.Blocks(): + gr.Markdown(acknowledgment_md) + + +def build_leaderboard_video_tab(score_result_file = 'sorted_score_list_video.json'): + with open(score_result_file, "r") as json_file: + data = json.load(json_file) + score_results = data["sorted_score_list"] + total_models = data["total_models"] + total_votes = data["total_votes"] + last_updated = data["last_updated"] + + md = make_leaderboard_video_md() + md_1 = gr.Markdown(md, elem_id="leaderboard_markdown") + # with gr.Blocks(): + # gr.HTML(make_disclaimer_md) + + # with gr.Tab("Arena Score", id=0): + md = make_arena_leaderboard_md(total_models, total_votes, last_updated) + gr.Markdown(md, elem_id="leaderboard_markdown") + md = make_arena_leaderboard_data(score_results) + gr.Dataframe(md) + + notice_markdown_sora = """ + - Note: When σ is large (we use the '*' labeling), it indicates that the model did not receive enough votes and its ranking is in the process of being updated. + - Note: As Sora's video generation function is not publicly available, we used sample videos from their official website. This may lead to a biased assessment of Sora's capabilities, as these samples likely represent Sora's best outputs. Therefore, Sora's position on our leaderboard should be considered as its upper bound. We are working on methods to conduct more comprehensive and fair comparisons in the future. + """ + + gr.Markdown(notice_markdown_sora, elem_id="notice_markdown_sora") + + gr.Markdown( + """ ### The leaderboard is regularly updated and continuously incorporates new models. + """, + elem_id="leaderboard_markdown", + ) + from .utils import acknowledgment_md, html_code + with gr.Blocks(): + gr.Markdown(acknowledgment_md) + + +def build_leaderboard_contributor(file = 'contributor.json'): + + with open(file, "r") as json_file: + data = json.load(json_file) + score_results = data["contributor"] + last_updated = data["last_updated"] + + md = f""" +# 🏆 Contributor Leaderboard +The submission of user information is entirely optional. This information is used solely for contribution statistics. We respect and safeguard users' privacy choices. +To maintain a clean and concise leaderboard, please ensure consistency in submitted names and affiliations. For example, use 'Berkeley' consistently rather than alternating with 'UC Berkeley'. +- Votes*: Each image vote counts as one Vote*, while each video vote counts as two Votes* due to the increased effort involved. +\n Last updated: {last_updated} +""" + + md_1 = gr.Markdown(md, elem_id="leaderboard_markdown") + + # md = make_arena_leaderboard_md(total_models, total_votes, last_updated) + # gr.Markdown(md, elem_id="leaderboard_markdown") + + md = make_arena_leaderboard_data(score_results) + gr.Dataframe(md) + + gr.Markdown( + """ ### The leaderboard is regularly updated. + """, + elem_id="leaderboard_markdown", + ) \ No newline at end of file diff --git a/serve/log_server.py b/serve/log_server.py new file mode 100644 index 0000000000000000000000000000000000000000..90e040e36ca1a653ac04dc650d1d2f38daaec705 --- /dev/null +++ b/serve/log_server.py @@ -0,0 +1,86 @@ +from fastapi import FastAPI, File, UploadFile, Form, APIRouter +from typing import Optional +import json +import os +import aiofiles +from .log_utils import build_logger +from .constants import LOG_SERVER_SUBDOMAIN, APPEND_JSON, SAVE_IMAGE, SAVE_VIDEO, SAVE_LOG + +logger = build_logger("log_server", "log_server.log", add_remote_handler=False) + +app = APIRouter(prefix=LOG_SERVER_SUBDOMAIN) + +@app.post(f"/{APPEND_JSON}") +async def append_json(json_str: str = Form(...), file_name: str = Form(...)): + """ + Appends a JSON string to a specified file. + """ + # Convert the string back to a JSON object (dict) + data = json.loads(json_str) + # Append the data to the specified file + if os.path.dirname(file_name): + os.makedirs(os.path.dirname(file_name), exist_ok=True) + async with aiofiles.open(file_name, mode='a') as f: + await f.write(json.dumps(data) + "\n") + + logger.info(f"Appended 1 JSON object to {file_name}") + return {"message": "JSON data appended successfully"} + +@app.post(f"/{SAVE_IMAGE}") +async def save_image(image: UploadFile = File(...), image_path: str = Form(...)): + """ + Saves an uploaded image to the specified path. + """ + # Note: 'image_path' should include the file name and extension for the image to be saved. + if os.path.dirname(image_path): + os.makedirs(os.path.dirname(image_path), exist_ok=True) + async with aiofiles.open(image_path, mode='wb') as f: + content = await image.read() # Read the content of the uploaded image + await f.write(content) # Write the image content to a file + logger.info(f"Image saved successfully at {image_path}") + return {"message": f"Image saved successfully at {image_path}"} + +@app.post(f"/{SAVE_VIDEO}") +async def save_video(video: UploadFile = File(...), video_path: str = Form(...)): + """ + Saves an uploaded video to the specified path. + """ + # Note: 'video_path' should include the file name and extension for the video to be saved. + if os.path.dirname(video_path): + os.makedirs(os.path.dirname(video_path), exist_ok=True) + async with aiofiles.open(video_path, mode='wb') as f: + content = await video.read() # Read the content of the uploaded video + await f.write(content) # Write the video content to a file + logger.info(f"Video saved successfully at {video_path}") + return {"message": f"Image saved successfully at {video_path}"} + +@app.post(f"/{SAVE_LOG}") +async def save_log(message: str = Form(...), log_path: str = Form(...)): + """ + Save a log message to a specified log file on the server. + """ + # Ensure the directory for the log file exists + if os.path.dirname(log_path): + os.makedirs(os.path.dirname(log_path), exist_ok=True) + + # Append the log message to the specified log file + async with aiofiles.open(log_path, mode='a') as f: + await f.write(f"{message}\n") + + logger.info(f"Romote log message saved to {log_path}") + return {"message": f"Log message saved successfully to {log_path}"} + + +@app.get(f"/read_file") +async def read_file(file_name: str): + """ + Reads the content of a specified file and returns it. + """ + if not os.path.exists(file_name): + return {"message": f"File {file_name} does not exist."} + + async with aiofiles.open(file_name, mode='r') as f: + content = await f.read() + + logger.info(f"Read file {file_name}") + return {"file_name": file_name, "content": content} diff --git a/serve/log_utils.py b/serve/log_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..00594acd7f951b19eb5852bc66df0a0eca40136e --- /dev/null +++ b/serve/log_utils.py @@ -0,0 +1,142 @@ +""" +Common utilities. +""" +from asyncio import AbstractEventLoop +import json +import logging +import logging.handlers +import os +import platform +import sys +from typing import AsyncGenerator, Generator +import warnings +from pathlib import Path + +import requests + +from .constants import LOGDIR, LOG_SERVER_ADDR, SAVE_LOG +from .utils import save_log_str_on_log_server + + +handler = None +visited_loggers = set() + + +# Assuming LOGDIR and other necessary imports and global variables are defined + +class APIHandler(logging.Handler): + """Custom logging handler that sends logs to an API.""" + + def __init__(self, apiUrl, log_path, *args, **kwargs): + super(APIHandler, self).__init__(*args, **kwargs) + self.apiUrl = apiUrl + self.log_path = log_path + + def emit(self, record): + log_entry = self.format(record) + try: + save_log_str_on_log_server(log_entry, self.log_path) + except requests.RequestException as e: + print(f"Error sending log to API: {e}", file=sys.stderr) + +def build_logger(logger_name, logger_filename, add_remote_handler=False): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + if sys.version_info[1] >= 9: + # This is for windows + logging.basicConfig(level=logging.INFO, encoding="utf-8") + else: + if platform.system() == "Windows": + warnings.warn( + "If you are running on Windows, " + "we recommend you use Python >= 3.9 for UTF-8 encoding." + ) + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + if add_remote_handler: + # Add APIHandler to send logs to your API + api_url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}" + + remote_logger_filename = str(Path(logger_filename).stem + "_remote.log") + api_handler = APIHandler(apiUrl=api_url, log_path=f"{LOGDIR}/{remote_logger_filename}") + api_handler.setFormatter(formatter) + logger.addHandler(api_handler) + + stdout_logger.addHandler(api_handler) + stderr_logger.addHandler(api_handler) + + # if LOGDIR is empty, then don't try output log to local file + if LOGDIR != "": + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when="D", utc=True, encoding="utf-8" + ) + handler.setFormatter(formatter) + + for l in [stdout_logger, stderr_logger, logger]: + if l in visited_loggers: + continue + visited_loggers.add(l) + l.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == "\n": + encoded_message = line.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = "" \ No newline at end of file diff --git a/serve/update_skill.py b/serve/update_skill.py new file mode 100644 index 0000000000000000000000000000000000000000..36ad7cddfd1538aaa5fb4b591911faa1e116220d --- /dev/null +++ b/serve/update_skill.py @@ -0,0 +1,119 @@ +import numpy as np +import json +from trueskill import TrueSkill +import paramiko +import io, os +import sys +from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL +trueskill_env = TrueSkill() +sys.path.append('../') +from model.models import IMAGE_GENERATION_MODELS + + +ssh_skill_client = None +sftp_skill_client = None + + +def create_ssh_skill_client(server, port, user, password): + global ssh_skill_client, sftp_skill_client + ssh_skill_client = paramiko.SSHClient() + ssh_skill_client.load_system_host_keys() + ssh_skill_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh_skill_client.connect(server, port, user, password) + + transport = ssh_skill_client.get_transport() + transport.set_keepalive(60) + + sftp_skill_client = ssh_skill_client.open_sftp() + + +def is_connected(): + global ssh_skill_client, sftp_skill_client + if ssh_skill_client is None or sftp_skill_client is None: + return False + if not ssh_skill_client.get_transport().is_active(): + return False + try: + sftp_skill_client.listdir('.') + except Exception as e: + print(f"Error checking SFTP connection: {e}") + return False + return True + + +def ucb_score(trueskill_diff, t, n): + exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5)) + ucb = -trueskill_diff + 1.0 * exploration_term + return ucb + + +def update_trueskill(ratings, ranks): + new_ratings = trueskill_env.rate(ratings, ranks) + return new_ratings + + +def serialize_rating(rating): + return {'mu': rating.mu, 'sigma': rating.sigma} + + +def deserialize_rating(rating_dict): + return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma']) + + +def save_json_via_sftp(ratings, comparison_counts, total_comparisons): + global sftp_skill_client + if not is_connected(): + create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + data = { + 'ratings': [serialize_rating(r) for r in ratings], + 'comparison_counts': comparison_counts.tolist(), + 'total_comparisons': total_comparisons + } + json_data = json.dumps(data) + with sftp_skill_client.open(SSH_SKILL, 'w') as f: + f.write(json_data) + + +def load_json_via_sftp(): + global sftp_skill_client + if not is_connected(): + create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + with sftp_skill_client.open(SSH_SKILL, 'r') as f: + data = json.load(f) + ratings = [deserialize_rating(r) for r in data['ratings']] + comparison_counts = np.array(data['comparison_counts']) + total_comparisons = data['total_comparisons'] + return ratings, comparison_counts, total_comparisons + + +def update_skill(rank, model_names, k_group=4): + + ratings, comparison_counts, total_comparisons = load_json_via_sftp() + + # group = Model_ID.group + group = [] + for model_name in model_names: + group.append(IMAGE_GENERATION_MODELS.index(model_name)) + print(group) + + pairwise_comparisons = [(i, j) for i in range(len(group)) for j in range(i+1, len(group))] + for player1, player2 in pairwise_comparisons: + if rank[player1] < rank[player2]: + ranks = [0, 1] + updated_ratings = update_trueskill([[ratings[group[player1]]], [ratings[group[player2]]]], ranks) + ratings[group[player1]], ratings[group[player2]] = updated_ratings[0][0], updated_ratings[1][0] + elif rank[player1] > rank[player2]: + ranks = [1, 0] + updated_ratings = update_trueskill([[ratings[group[player1]]], [ratings[group[player2]]]], ranks) + ratings[group[player1]], ratings[group[player2]] = updated_ratings[0][0], updated_ratings[1][0] + + comparison_counts[group[player1], group[player2]] += 1 + comparison_counts[group[player2], group[player1]] += 1 + + total_comparisons += 1 + + save_json_via_sftp(ratings, comparison_counts, total_comparisons) + + from model.matchmaker import RunningPivot + if group[0] in RunningPivot.running_pivot: + RunningPivot.running_pivot.remove(group[0]) \ No newline at end of file diff --git a/serve/update_skill_video.py b/serve/update_skill_video.py new file mode 100644 index 0000000000000000000000000000000000000000..dff7ab985d51c08841a8b6c2f18da9e4b50792a6 --- /dev/null +++ b/serve/update_skill_video.py @@ -0,0 +1,115 @@ +import numpy as np +import json +from trueskill import TrueSkill +import paramiko +import io, os +import sys +from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_VIDEO_SKILL +trueskill_env = TrueSkill() +sys.path.append('../') +from model.models import VIDEO_GENERATION_MODELS + + +ssh_skill_client = None +sftp_skill_client = None + + +def create_ssh_skill_client(server, port, user, password): + global ssh_skill_client, sftp_skill_client + ssh_skill_client = paramiko.SSHClient() + ssh_skill_client.load_system_host_keys() + ssh_skill_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh_skill_client.connect(server, port, user, password) + + transport = ssh_skill_client.get_transport() + transport.set_keepalive(60) + + sftp_skill_client = ssh_skill_client.open_sftp() + + +def is_connected(): + global ssh_skill_client, sftp_skill_client + if ssh_skill_client is None or sftp_skill_client is None: + return False + if not ssh_skill_client.get_transport().is_active(): + return False + try: + sftp_skill_client.listdir('.') + except Exception as e: + print(f"Error checking SFTP connection: {e}") + return False + return True + + +def ucb_score(trueskill_diff, t, n): + exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5)) + ucb = -trueskill_diff + 1.0 * exploration_term + return ucb + + +def update_trueskill(ratings, ranks): + new_ratings = trueskill_env.rate(ratings, ranks) + return new_ratings + + +def serialize_rating(rating): + return {'mu': rating.mu, 'sigma': rating.sigma} + + +def deserialize_rating(rating_dict): + return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma']) + + +def save_json_via_sftp(ratings, comparison_counts, total_comparisons): + global sftp_skill_client + if not is_connected(): + create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + data = { + 'ratings': [serialize_rating(r) for r in ratings], + 'comparison_counts': comparison_counts.tolist(), + 'total_comparisons': total_comparisons + } + json_data = json.dumps(data) + with sftp_skill_client.open(SSH_VIDEO_SKILL, 'w') as f: + f.write(json_data) + + +def load_json_via_sftp(): + global sftp_skill_client + if not is_connected(): + create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + with sftp_skill_client.open(SSH_VIDEO_SKILL, 'r') as f: + data = json.load(f) + ratings = [deserialize_rating(r) for r in data['ratings']] + comparison_counts = np.array(data['comparison_counts']) + total_comparisons = data['total_comparisons'] + return ratings, comparison_counts, total_comparisons + + +def update_skill_video(rank, model_names, k_group=4): + + ratings, comparison_counts, total_comparisons = load_json_via_sftp() + + # group = Model_ID.group + group = [] + for model_name in model_names: + group.append(VIDEO_GENERATION_MODELS.index(model_name)) + print(group) + + pairwise_comparisons = [(i, j) for i in range(len(group)) for j in range(i+1, len(group))] + for player1, player2 in pairwise_comparisons: + if rank[player1] < rank[player2]: + ranks = [0, 1] + updated_ratings = update_trueskill([[ratings[group[player1]]], [ratings[group[player2]]]], ranks) + ratings[group[player1]], ratings[group[player2]] = updated_ratings[0][0], updated_ratings[1][0] + elif rank[player1] > rank[player2]: + ranks = [1, 0] + updated_ratings = update_trueskill([[ratings[group[player1]]], [ratings[group[player2]]]], ranks) + ratings[group[player1]], ratings[group[player2]] = updated_ratings[0][0], updated_ratings[1][0] + + comparison_counts[group[player1], group[player2]] += 1 + comparison_counts[group[player2], group[player1]] += 1 + + total_comparisons += 1 + + save_json_via_sftp(ratings, comparison_counts, total_comparisons) \ No newline at end of file diff --git a/serve/upload.py b/serve/upload.py new file mode 100644 index 0000000000000000000000000000000000000000..692d183696c3735b9a3bf19f90fa961d0734c442 --- /dev/null +++ b/serve/upload.py @@ -0,0 +1,278 @@ +import paramiko +import numpy as np +import io, os, stat +import gradio as gr +from PIL import Image +import requests +import json +import random +import concurrent.futures +from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO + + +ssh_client = None +sftp_client = None +sftp_client_imgs = None + + +def open_sftp(i=0): + global ssh_client + sftp_client = ssh_client.open_sftp() + return sftp_client + + +def create_ssh_client(server, port, user, password): + global ssh_client, sftp_client, sftp_client_imgs + ssh_client = paramiko.SSHClient() + ssh_client.load_system_host_keys() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh_client.connect(server, port, user, password) + + transport = ssh_client.get_transport() + transport.set_keepalive(60) + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(open_sftp, i) for i in range(5)] + results = [future.result() for future in futures] + + sftp_client = results[0] + sftp_client_imgs = results[1:] + + +def is_connected(): + global ssh_client, sftp_client + if ssh_client is None or sftp_client is None: + return False + if not ssh_client.get_transport().is_active(): + return False + try: + sftp_client.listdir('.') + except Exception as e: + print(f"Error checking SFTP connection: {e}") + return False + return True + + +def get_image_from_url(image_url): + response = requests.get(image_url) + response.raise_for_status() # success + return Image.open(io.BytesIO(response.content)) + + +# def get_random_mscoco_prompt(): +# global sftp_client +# if not is_connected(): +# create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) +# num = random.randint(0, 2999) +# file = "{}.txt".format(num) + +# remote_file_path = os.path.join(SSH_MSCOCO, file) +# with sftp_client.file(remote_file_path, 'r') as f: +# content = f.read().decode('utf-8') +# print(f"Content of {file}:") +# print("\n") +# return content + + +def get_random_mscoco_prompt(): + + file_path = './coco_prompt.txt' + with open(file_path, 'r') as file: + lines = file.readlines() + + random_line = random.choice(lines).strip() + return random_line + + +def get_random_video_prompt(root_dir): + subdirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))] + if not subdirs: + raise NotImplementedError + selected_dir = random.choice(subdirs) + prompt_path = os.path.join(selected_dir, 'prompt.txt') + + if os.path.exists(prompt_path): + str_list = [] + with open(prompt_path, 'r', encoding='utf-8') as file: + for line in file: + str_list.append(line.strip()) + prompt = str_list[0] + else: + raise NotImplementedError + return selected_dir, prompt + + +def get_ssh_random_video_prompt(root_dir, local_dir, model_names): + def is_directory(sftp, path): + try: + return stat.S_ISDIR(sftp.stat(path).st_mode) + except IOError: + return False + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + ssh.connect(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + sftp = ssh.open_sftp() + + remote_subdirs = sftp.listdir(root_dir) + remote_subdirs = [d for d in remote_subdirs if is_directory(sftp, os.path.join(root_dir, d))] + + if not remote_subdirs: + print(f"No subdirectories found in {root_dir}") + raise NotImplementedError + + chosen_subdir = random.choice(remote_subdirs) + chosen_subdir_path = os.path.join(root_dir, chosen_subdir) + print(f"Chosen subdirectory: {chosen_subdir_path}") + + prompt_path = 'prompt.txt' + results = [prompt_path] + for name in model_names: + model_source, model_name, model_type = name.split("_") + video_path = f'{model_name}.mp4' + print(video_path) + results.append(video_path) + + local_path = [] + for tar_file in results: + remote_file_path = os.path.join(chosen_subdir_path, tar_file) + local_file_path = os.path.join(local_dir, tar_file) + sftp.get(remote_file_path, local_file_path) + local_path.append(local_file_path) + print(f"Downloaded {remote_file_path} to {local_file_path}") + + if os.path.exists(local_path[0]): + str_list = [] + with open(local_path[0], 'r', encoding='utf-8') as file: + for line in file: + str_list.append(line.strip()) + prompt = str_list[0] + else: + raise NotImplementedError + except Exception as e: + print(f"An error occurred: {e}") + raise NotImplementedError + sftp.close() + ssh.close() + return prompt, local_path[1:] + + +def get_ssh_random_image_prompt(root_dir, local_dir, model_names): + def is_directory(sftp, path): + try: + return stat.S_ISDIR(sftp.stat(path).st_mode) + except IOError: + return False + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + ssh.connect(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + sftp = ssh.open_sftp() + + remote_subdirs = sftp.listdir(root_dir) + remote_subdirs = [d for d in remote_subdirs if is_directory(sftp, os.path.join(root_dir, d))] + + if not remote_subdirs: + print(f"No subdirectories found in {root_dir}") + raise NotImplementedError + + chosen_subdir = random.choice(remote_subdirs) + chosen_subdir_path = os.path.join(root_dir, chosen_subdir) + print(f"Chosen subdirectory: {chosen_subdir_path}") + + prompt_path = 'prompt.txt' + results = [prompt_path] + for name in model_names: + model_source, model_name, model_type = name.split("_") + image_path = f'{model_name}.jpg' + print(image_path) + results.append(image_path) + + local_path = [] + for tar_file in results: + remote_file_path = os.path.join(chosen_subdir_path, tar_file) + local_file_path = os.path.join(local_dir, tar_file) + sftp.get(remote_file_path, local_file_path) + local_path.append(local_file_path) + print(f"Downloaded {remote_file_path} to {local_file_path}") + + if os.path.exists(local_path[0]): + str_list = [] + with open(local_path[0], 'r', encoding='utf-8') as file: + for line in file: + str_list.append(line.strip()) + prompt = str_list[0] + else: + raise NotImplementedError + except Exception as e: + print(f"An error occurred: {e}") + raise NotImplementedError + sftp.close() + ssh.close() + return prompt, [Image.open(path) for path in local_path[1:]] + + +def create_remote_directory(remote_directory, video=False): + global ssh_client + if not is_connected(): + create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + if video: + log_dir = f'{SSH_VIDEO_LOG}/{remote_directory}' + else: + log_dir = f'{SSH_LOG}/{remote_directory}' + stdin, stdout, stderr = ssh_client.exec_command(f'mkdir -p {log_dir}') + error = stderr.read().decode('utf-8') + if error: + print(f"Error: {error}") + else: + print(f"Directory {remote_directory} created successfully.") + return log_dir + + +def upload_images(i, image_list, output_file_list, sftp_client): + with sftp_client as sftp: + if isinstance(image_list[i], str): + print("get url image") + image_list[i] = get_image_from_url(image_list[i]) + with io.BytesIO() as image_byte_stream: + image_list[i] = image_list[i].resize((512, 512), Image.ANTIALIAS) + image_list[i].save(image_byte_stream, format='JPEG') + image_byte_stream.seek(0) + sftp.putfo(image_byte_stream, output_file_list[i]) + print(f"Successfully uploaded image to {output_file_list[i]}") + + +def upload_ssh_all(states, output_dir, data, data_path): + global sftp_client + global sftp_client_imgs + if not is_connected(): + create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + output_file_list = [] + image_list = [] + for i in range(len(states)): + output_file = os.path.join(output_dir, f"{i}.jpg") + output_file_list.append(output_file) + image_list.append(states[i].output) + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(upload_images, i, image_list, output_file_list, sftp_client_imgs[i]) for i in range(len(output_file_list))] + + with sftp_client as sftp: + json_data = json.dumps(data, indent=4) + with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream: + sftp.putfo(json_byte_stream, data_path) + print(f"Successfully uploaded JSON data to {data_path}") + # create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + + +def upload_ssh_data(data, data_path): + global sftp_client + global sftp_client_imgs + if not is_connected(): + create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) + + with sftp_client as sftp: + json_data = json.dumps(data, indent=4) + with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream: + sftp.putfo(json_byte_stream, data_path) + print(f"Successfully uploaded JSON data to {data_path}") diff --git a/serve/utils.py b/serve/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..afeafe4103f2a500471dbc1aaedb8dac20732444 --- /dev/null +++ b/serve/utils.py @@ -0,0 +1,298 @@ +import os +import json +import datetime +import requests +import numpy as np +import gradio as gr +from pathlib import Path +from model.model_registry import * +from .constants import LOGDIR, LOG_SERVER_ADDR, APPEND_JSON, SAVE_IMAGE, SAVE_VIDEO, SAVE_LOG +from typing import Union + + +enable_btn = gr.update(interactive=True, visible=True) +disable_btn = gr.update(interactive=False) +invisible_btn = gr.update(interactive=False, visible=False) +no_change_btn = gr.update(value="No Change", interactive=True, visible=True) + + +def build_about(): + about_markdown = f""" +# About Us + +""" + + gr.Markdown(about_markdown, elem_id="about_markdown") + + +acknowledgment_md = """ +### Acknowledgment +
+

Our codebase is built upon FastChat, ImagenHub.

+
+""" +# 定义一个HTML组件来创建链接和处理点击事件 +html_code = """ +

This platform is designed for academic usage, for details please refer to disclaimer.

+ + +""" +block_css = """ +#notice_markdown { + font-size: 110% +} +#notice_markdown th { + display: none; +} +#notice_markdown td { + padding-top: 6px; + padding-bottom: 6px; +} +#model_description_markdown { + font-size: 110% +} +#leaderboard_markdown { + font-size: 110% +} +#leaderboard_markdown td { + padding-top: 6px; + padding-bottom: 6px; +} +#leaderboard_dataframe td { + line-height: 0.1em; +} +#about_markdown { + font-size: 110% +} +#ack_markdown { + font-size: 110% +} +#input_box textarea { +} +footer { + display:none !important +} +.image-about img { + margin: 0 30px; + margin-top: 30px; + height: 60px; + max-height: 100%; + width: auto; + float: left; +.input-image, .image-preview { + margin: 0 30px; + height: 30px; + max-height: 100%; + width: auto; + max-width: 30%;} +} + +.custom-button { + border-radius: 8px; +} +.best-button { + border-radius: 8px; +} +.row { + display: flex; + justify-content: space-between; +} +.send-button { + background: rgb(168, 230, 207); + color: rgb(0, 198, 171); +} +.submit-button { + color: red; +} +#custom-width {width: 100px !important;} +#centered-row { + display: flex; + justify-content: center; +} +#btncolor1 {background: rgb(168, 230, 207);} +#btncolor2 {background: rgb(253, 255, 171);} +#btncolor3 {background: rgb(255, 211, 182);} +#btncolor4 {background: rgb(255, 170, 165);} + +#btnblue {background: linear-gradient(to bottom right, rgb(222, 235, 247), rgb(189,215,238)); color: rgb(0, 112, 192); border: 1px solid rgb(189,215,238);} +#btnpink {background: rgb(255, 168, 184);} +#centered-text { display: flex; justify-content: center; align-items: center; height: 100%; width: 100%; font-size: 150%; } + +""" + +# + +#btncolor1 {background: rgb(128, 214, 255);} +#btncolor2 {background: rgb(237, 247, 152);} +#btncolor3 {background: rgb(250, 181, 122);} +#btncolor4 {background: rgb(240, 104, 104);} + +#btncolor1 {background: rgb(112, 161, 215);} +#btncolor2 {background: rgb(161, 222, 147);} +#btncolor3 {background: rgb(247, 244, 139);} +#btncolor4 {background: rgb(244, 124, 124);} + +#btncolor1 {background: rgb(168, 230, 207);} +#btncolor2 {background: rgb(253, 255, 171);} +#btncolor3 {background: rgb(255, 211, 182);} +#btncolor4 {background: rgb(255, 170, 165);} + +#btncolor1 {background: rgb(255, 212, 96);} +#btncolor2 {background: rgb(240, 123, 63);} +#btncolor3 {background: rgb(234, 84, 85);} +#btncolor4 {background: rgb(45, 64, 89);} + +#btncolor1 {background: rgb(255, 189, 57);} +#btncolor2 {background: rgb(230, 28, 93);} +#btncolor3 {background: rgb(147, 0, 119);} +#btncolor4 {background: rgb(58, 0, 136);} +# max-width: 100px; +# .custom-button { +# padding: 10px 15px; +# text-align: center; +# text-decoration: none; +# display: inline-block; +# font-size: 16px; +# cursor: pointer; +# border-radius: 8px; +# } +# { +# background-color: green; /* 背景颜色 */ +# color: white; /* 文字颜色 */ +# border: none; /* 无边框 */ +# padding: 10px 20px; /* 内边距 */ +# text-align: center; /* 文本居中 */ +# text-decoration: none; /* 无下划线 */ +# display: inline-block; /* 行内块 */ +# font-size: 16px; /* 字体大小 */ +# margin: 4px 2px; /* 外边距 */ +# cursor: pointer; /* 鼠标指针 */ +# border-radius: 5px; /* 圆角边框 */ +# } +# .custom-button:hover { +# background-color: darkgreen; /* 悬停时的背景颜色 */ +# } +def enable_loop_buttons(): + return tuple(gr.update(loop=True) for i in range(4)) +def enable_vote_buttons(): + return tuple(gr.update(visible=True, interactive=i<=4) for i in range(6)) +def disable_vote_buttons(): + return tuple(gr.update(visible=False, interactive=False) for i in range(6)) +def disable_vote(): + return tuple(gr.update(interactive=False) for i in range(18)) +def enable_vote_mode_buttons(mode, textbox): + print(mode) + if not textbox.strip(): + return tuple(gr.update(visible=False, interactive=False) for _ in range(24)) + else: + if mode == "Best": + return (gr.update(visible=True, interactive=True),) * 5 + \ + (gr.update(visible=False, interactive=False),) * 16 + \ + (gr.update(visible=True, interactive=False),) * 2 + \ + (gr.update(visible=True, interactive=True),) + elif mode == "Rank": + return (gr.update(visible=False, interactive=False),) * 5 + \ + (gr.update(visible=True, interactive=True),) * 16 + \ + (gr.update(visible=True, interactive=True),) * 3 +def disable_vote_mode_buttons(): + return tuple(gr.update(visible=False, interactive=False) for _ in range(24)) + + +def enable_order_buttons(): + return tuple(gr.update(interactive=True) for _ in range(5)) +def disable_order_buttons(textbox, cache="False"): + if cache=="True": + return (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=True)) + if not textbox.strip(): + return (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)) + else: + return (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)) + +def enable_video_order_buttons(): + return tuple(gr.update(interactive=True) for _ in range(4)) +def disable_video_order_buttons(): + return (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True)) + +def clear_history(): + return None, "", None + +def clear_history_side_by_side(): + return None, None, "", None, None + +def clear_history_side_by_side_anony(): + return None, None, None, None, gr.update(visible=True, interactive=True, value=""), gr.update(visible=True, interactive=True, value=""), None, None, None, None, \ + gr.Markdown("", visible=False), gr.Markdown("", visible=False), gr.Markdown("", visible=False), gr.Markdown("", visible=False) + +def clear_history_ie(): + return None, "", "", "", None, None + +def clear_history_side_by_side_ie(): + return None, None, "", "", "", None, None, None + +def clear_history_side_by_side_ie_anony(): + return None, None, "", "", "", None, None, None, gr.Markdown("", visible=False), gr.Markdown("", visible=False) + +def get_ip(request: gr.Request): + if request: + if "cf-connecting-ip" in request.headers: + ip = request.headers["cf-connecting-ip"] or request.client.host + else: + ip = request.client.host + else: + ip = None + return ip + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + +def save_image_file_on_log_server(image_file:str): + return 1 + + image_file = Path(image_file).absolute().relative_to(os.getcwd()) + image_file = str(image_file) + # Open the image file in binary mode + url = f"{LOG_SERVER_ADDR}/{SAVE_IMAGE}" + with open(image_file, 'rb') as f: + # Make the POST request, sending the image file and the image path + response = requests.post(url, files={'image': f}, data={'image_path': image_file}) + return response + +def save_video_file_on_log_server(video_file:str): + return 1 + + video_file = Path(video_file).absolute().relative_to(os.getcwd()) + video_file = str(video_file) + # Open the video file in binary mode + url = f"{LOG_SERVER_ADDR}/{SAVE_VIDEO}" + with open(video_file, 'rb') as f: + # Make the POST request, sending the video file and the video path + response = requests.post(url, files={'video': f}, data={'video_path': video_file}) + return response + +def append_json_item_on_log_server(json_item: Union[dict, str], log_file: str): + return 1 + + if isinstance(json_item, dict): + json_item = json.dumps(json_item) + log_file = Path(log_file).absolute().relative_to(os.getcwd()) + log_file = str(log_file) + url = f"{LOG_SERVER_ADDR}/{APPEND_JSON}" + # Make the POST request, sending the JSON string and the log file name + response = requests.post(url, data={'json_str': json_item, 'file_name': log_file}) + return response + +def save_log_str_on_log_server(log_str: str, log_file: str): + return 1 + + log_file = Path(log_file).absolute().relative_to(os.getcwd()) + log_file = str(log_file) + url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}" + # Make the POST request, sending the log message and the log file name + response = requests.post(url, data={'message': log_str, 'log_path': log_file}) + return response \ No newline at end of file diff --git a/serve/vote_utils.py b/serve/vote_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7179fbdeb773570ef0c3fbad6c0db0ea42052c7 --- /dev/null +++ b/serve/vote_utils.py @@ -0,0 +1,1686 @@ +import datetime +import time +import json +import uuid +import gradio as gr +import regex as re +from pathlib import Path +from .utils import * +from .log_utils import build_logger +from .constants import IMAGE_DIR, VIDEO_DIR +import imageio +from diffusers.utils import load_image + +ig_logger = build_logger("gradio_web_server_image_generation", "gr_web_image_generation.log") # ig = image generation, loggers for single model direct chat +igm_logger = build_logger("gradio_web_server_image_generation_multi", "gr_web_image_generation_multi.log") # igm = image generation multi, loggers for side-by-side and battle +ie_logger = build_logger("gradio_web_server_image_editing", "gr_web_image_editing.log") # ie = image editing, loggers for single model direct chat +iem_logger = build_logger("gradio_web_server_image_editing_multi", "gr_web_image_editing_multi.log") # iem = image editing multi, loggers for side-by-side and battle +vg_logger = build_logger("gradio_web_server_video_generation", "gr_web_video_generation.log") # vg = video generation, loggers for single model direct chat +vgm_logger = build_logger("gradio_web_server_video_generation_multi", "gr_web_video_generation_multi.log") # vgm = video generation multi, loggers for side-by-side and battle + +def save_any_image(image_file, file_path): + if isinstance(image_file, str): + image = load_image(image_file) + image.save(file_path, 'JPEG') + else: + image_file.save(file_path, 'JPEG') + +def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(output_file) + +def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + for state in states: + print(state.conv_id) + output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(output_file) + +def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg' + source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + with open(source_file, 'w') as sf: + save_any_image(state.source_image, sf) + save_image_file_on_log_server(output_file) + save_image_file_on_log_server(source_file) + +def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + for state in states: + output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg' + source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + with open(source_file, 'w') as sf: + save_any_image(state.source_image, sf) + save_image_file_on_log_server(output_file) + save_image_file_on_log_server(source_file) + + +def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + if state.model_name.startswith('fal'): + r = requests.get(state.output) + with open(output_file, 'wb') as outfile: + outfile.write(r.content) + else: + print("======== video shape: ========") + print(state.output.shape) + imageio.mimwrite(output_file, state.output, fps=8, quality=9) + save_video_file_on_log_server(output_file) + + + +def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + for state in states: + output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + if state.model_name.startswith('fal'): + r = requests.get(state.output) + with open(output_file, 'wb') as outfile: + outfile.write(r.content) + else: + print("======== video shape: ========") + print(state.output.shape) + imageio.mimwrite(output_file, state.output, fps=8, quality=9) + save_video_file_on_log_server(output_file) + + +## Image Generation (IG) Single Model Direct Chat +def upvote_last_response_ig(state, model_selector, request: gr.Request): + ip = get_ip(request) + ig_logger.info(f"upvote. ip: {ip}") + vote_last_response_ig(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + +def downvote_last_response_ig(state, model_selector, request: gr.Request): + ip = get_ip(request) + ig_logger.info(f"downvote. ip: {ip}") + vote_last_response_ig(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response_ig(state, model_selector, request: gr.Request): + ip = get_ip(request) + ig_logger.info(f"flag. ip: {ip}") + vote_last_response_ig(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + +## Image Generation Multi (IGM) Side-by-Side and Battle + +def leftvote_last_response_igm( + state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, request: gr.Request +): + igm_logger.info(f"leftvote (named). ip: {get_ip(request)}") + vote_last_response_igm( + [state0, state1, state2, state3], "leftvote", [model_selector0, model_selector1, model_selector2, model_selector3], request + ) + if model_selector0 == "": + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True) + ) + else: + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(state0.model_name, visible=True), + gr.Markdown(state1.model_name, visible=True), + gr.Markdown(state2.model_name, visible=True), + gr.Markdown(state3.model_name, visible=True) + ) +def left1vote_last_response_igm( + state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, request: gr.Request +): + igm_logger.info(f"left1vote (named). ip: {get_ip(request)}") + vote_last_response_igm( + [state0, state1, state2, state3], "left1vote", [model_selector0, model_selector1, model_selector2, model_selector3], request + ) + if model_selector0 == "": + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True) + ) + else: + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(state0.model_name, visible=True), + gr.Markdown(state1.model_name, visible=True), + gr.Markdown(state2.model_name, visible=True), + gr.Markdown(state3.model_name, visible=True) + ) +def rightvote_last_response_igm( + state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, request: gr.Request +): + igm_logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response_igm( + [state0, state1, state2, state3], "rightvote", [model_selector0, model_selector1, model_selector2, model_selector3], request + ) + # print(model_selector0) + if model_selector0 == "": + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True) + ) + else: + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(state0.model_name, visible=True), + gr.Markdown(state1.model_name, visible=True), + gr.Markdown(state2.model_name, visible=True), + gr.Markdown(state3.model_name, visible=True) + ) +def right1vote_last_response_igm( + state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, request: gr.Request +): + igm_logger.info(f"right1vote (named). ip: {get_ip(request)}") + vote_last_response_igm( + [state0, state1, state2, state3], "right1vote", [model_selector0, model_selector1, model_selector2, model_selector3], request + ) + # print(model_selector0) + if model_selector0 == "": + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True) + ) + else: + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(state0.model_name, visible=True), + gr.Markdown(state1.model_name, visible=True), + gr.Markdown(state2.model_name, visible=True), + gr.Markdown(state3.model_name, visible=True) + ) + +def tievote_last_response_igm( + state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, request: gr.Request +): + igm_logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response_igm( + [state0, state1, state2, state3], "tievote", [model_selector0, model_selector1, model_selector2, model_selector3], request + ) + if model_selector0 == "": + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True) + ) + else: + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(state0.model_name, visible=True), + gr.Markdown(state1.model_name, visible=True), + gr.Markdown(state2.model_name, visible=True), + gr.Markdown(state3.model_name, visible=True) + ) + + +def bothbad_vote_last_response_igm( + state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, request: gr.Request +): + igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response_igm( + [state0, state1, state2, state3], "bothbad_vote", [model_selector0, model_selector1, model_selector2, model_selector3], request + ) + if model_selector0 == "": + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True) + ) + else: + return ("",) + (disable_btn,) * 6 + ( + gr.Markdown(state0.model_name, visible=True), + gr.Markdown(state1.model_name, visible=True), + gr.Markdown(state2.model_name, visible=True), + gr.Markdown(state3.model_name, visible=True) + ) + +## Image Editing (IE) Single Model Direct Chat + +def upvote_last_response_ie(state, model_selector, request: gr.Request): + ip = get_ip(request) + ie_logger.info(f"upvote. ip: {ip}") + vote_last_response_ie(state, "upvote", model_selector, request) + return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3 + +def downvote_last_response_ie(state, model_selector, request: gr.Request): + ip = get_ip(request) + ie_logger.info(f"downvote. ip: {ip}") + vote_last_response_ie(state, "downvote", model_selector, request) + return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3 + +def flag_last_response_ie(state, model_selector, request: gr.Request): + ip = get_ip(request) + ie_logger.info(f"flag. ip: {ip}") + vote_last_response_ie(state, "flag", model_selector, request) + return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3 + +## Image Editing Multi (IEM) Side-by-Side and Battle +def leftvote_last_response_iem( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}") + vote_last_response_iem( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + # names = ( + # "### Model A: " + state0.model_name, + # "### Model B: " + state1.model_name, + # ) + # names = (state0.model_name, state1.model_name) + if model_selector0 == "": + names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) + else: + names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) + return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4 + +def rightvote_last_response_iem( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}") + vote_last_response_iem( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + # names = ( + # "### Model A: " + state0.model_name, + # "### Model B: " + state1.model_name, + # ) + if model_selector0 == "": + names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) + else: + names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) + return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4 + +def tievote_last_response_iem( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + iem_logger.info(f"tievote (anony). ip: {get_ip(request)}") + vote_last_response_iem( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + if model_selector0 == "": + names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) + else: + names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) + return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4 + +def bothbad_vote_last_response_iem( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}") + vote_last_response_iem( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + if model_selector0 == "": + names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) + else: + names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) + return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4 + + +## Video Generation (VG) Single Model Direct Chat +def upvote_last_response_vg(state, model_selector, request: gr.Request): + ip = get_ip(request) + vg_logger.info(f"upvote. ip: {ip}") + vote_last_response_vg(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + +def downvote_last_response_vg(state, model_selector, request: gr.Request): + ip = get_ip(request) + vg_logger.info(f"downvote. ip: {ip}") + vote_last_response_vg(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response_vg(state, model_selector, request: gr.Request): + ip = get_ip(request) + vg_logger.info(f"flag. ip: {ip}") + vote_last_response_vg(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + +## Image Generation Multi (IGM) Side-by-Side and Battle + +def leftvote_last_response_vgm( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}") + vote_last_response_vgm( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + if model_selector0 == "": + return ("",) + (disable_btn,) * 4 + (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) + else: + return ("",) + (disable_btn,) * 4 + ( + gr.Markdown(state0.model_name, visible=False), + gr.Markdown(state1.model_name, visible=False)) + + +def rightvote_last_response_vgm( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response_vgm( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + if model_selector0 == "": + return ("",) + (disable_btn,) * 4 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) + else: + return ("",) + (disable_btn,) * 4 + ( + gr.Markdown(state0.model_name, visible=False), + gr.Markdown(state1.model_name, visible=False)) + +def tievote_last_response_vgm( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + vgm_logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response_vgm( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + if model_selector0 == "": + return ("",) + (disable_btn,) * 4 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) + else: + return ("",) + (disable_btn,) * 4 + ( + gr.Markdown(state0.model_name, visible=False), + gr.Markdown(state1.model_name, visible=False)) + + +def bothbad_vote_last_response_vgm( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response_vgm( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + if model_selector0 == "": + return ("",) + (disable_btn,) * 4 + ( + gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), + gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) + else: + return ("",) + (disable_btn,) * 4 + ( + gr.Markdown(state0.model_name, visible=False), + gr.Markdown(state1.model_name, visible=False)) + +share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" +def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request): + igm_logger.info(f"share (anony). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response_igm( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + +def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request): + iem_logger.info(f"share (anony). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response_iem( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + +## All Generation Gradio Interface + +class ImageStateIG: + def __init__(self, model_name): + self.conv_id = uuid.uuid4().hex + self.model_name = model_name + self.prompt = None + self.output = None + + def dict(self): + base = { + "conv_id": self.conv_id, + "model_name": self.model_name, + "prompt": self.prompt + } + return base + +class ImageStateIE: + def __init__(self, model_name): + self.conv_id = uuid.uuid4().hex + self.model_name = model_name + self.source_prompt = None + self.target_prompt = None + self.instruct_prompt = None + self.source_image = None + self.output = None + + def dict(self): + base = { + "conv_id": self.conv_id, + "model_name": self.model_name, + "source_prompt": self.source_prompt, + "target_prompt": self.target_prompt, + "instruct_prompt": self.instruct_prompt + } + return base + +class VideoStateVG: + def __init__(self, model_name): + self.conv_id = uuid.uuid4().hex + self.model_name = model_name + self.prompt = None + self.output = None + + def dict(self): + base = { + "conv_id": self.conv_id, + "model_name": self.model_name, + "prompt": self.prompt + } + return base + + +def generate_ig(gen_func, state, text, model_name, request: gr.Request): + if not text: + raise gr.Warning("Prompt cannot be empty.") + if not model_name: + raise gr.Warning("Model name cannot be empty.") + if state is None: + state = ImageStateIG(model_name) + ip = get_ip(request) + ig_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + generated_image = gen_func(text, model_name) + state.prompt = text + state.output = generated_image + state.model_name = model_name + + yield state, generated_image + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(output_file) + +def generate_ig_museum(gen_func, state, model_name, request: gr.Request): + if not model_name: + raise gr.Warning("Model name cannot be empty.") + if state is None: + state = ImageStateIG(model_name) + ip = get_ip(request) + ig_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + generated_image, text = gen_func(model_name) + state.prompt = text + state.output = generated_image + state.model_name = model_name + + yield state, generated_image, text + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(output_file) + +def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request): + if not text: + raise gr.Warning("Prompt cannot be empty.") + if not model_name0: + raise gr.Warning("Model name A cannot be empty.") + if not model_name1: + raise gr.Warning("Model name B cannot be empty.") + if state0 is None: + state0 = ImageStateIG(model_name0) + if state1 is None: + state1 = ImageStateIG(model_name1) + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + # Remove ### Model (A|B): from model name + model_name0 = re.sub(r"### Model A: ", "", model_name0) + model_name1 = re.sub(r"### Model B: ", "", model_name1) + generated_image0, generated_image1 = gen_func(text, model_name0, model_name1) + state0.prompt = text + state1.prompt = text + state0.output = generated_image0 + state1.output = generated_image1 + state0.model_name = model_name0 + state1.model_name = model_name1 + + yield state0, state1, generated_image0, generated_image1 + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name0, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state0.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name1, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state1.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + for i, state in enumerate([state0, state1]): + output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(output_file) + +def generate_igm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): + if not model_name0: + raise gr.Warning("Model name A cannot be empty.") + if not model_name1: + raise gr.Warning("Model name B cannot be empty.") + if state0 is None: + state0 = ImageStateIG(model_name0) + if state1 is None: + state1 = ImageStateIG(model_name1) + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + # Remove ### Model (A|B): from model name + model_name0 = re.sub(r"### Model A: ", "", model_name0) + model_name1 = re.sub(r"### Model B: ", "", model_name1) + generated_image0, generated_image1, text = gen_func(model_name0, model_name1) + state0.prompt = text + state1.prompt = text + state0.output = generated_image0 + state1.output = generated_image1 + state0.model_name = model_name0 + state1.model_name = model_name1 + + yield state0, state1, generated_image0, generated_image1, text + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name0, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state0.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name1, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state1.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + for i, state in enumerate([state0, state1]): + output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(output_file) + + +def generate_igm_annoy(gen_func, state0, state1, state2, state3, text, model_name0, model_name1, model_name2, model_name3, request: gr.Request): + if not text.strip(): + return (gr.update(visible=False),) * 16 + if state0 is None: + state0 = ImageStateIG(model_name0) + if state1 is None: + state1 = ImageStateIG(model_name1) + if state2 is None: + state2 = ImageStateIG(model_name2) + if state3 is None: + state3 = ImageStateIG(model_name3) + + + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = "" + model_name1 = "" + model_name2 = "" + model_name3 = "" + + generated_image0, generated_image1, generated_image2, generated_image3, model_name0, model_name1, model_name2, model_name3 \ + = gen_func(text, model_name0, model_name1, model_name2, model_name3) + state0.prompt = text + state1.prompt = text + state2.prompt = text + state3.prompt = text + + state0.output = generated_image0 + state1.output = generated_image1 + state2.output = generated_image2 + state3.output = generated_image3 + + state0.model_name = model_name0 + state1.model_name = model_name1 + state2.model_name = model_name2 + state3.model_name = model_name3 + + yield state0, state1, state2, state3, generated_image0, generated_image1, generated_image2, generated_image3, \ + generated_image0, generated_image1, generated_image2, generated_image3, \ + gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False), \ + gr.Markdown(f"### Model C: {model_name2}", visible=False), gr.Markdown(f"### Model D: {model_name3}", visible=False) + + # finish_tstamp = time.time() + # # logger.info(f"===output===: {output}") + + # with open(get_conv_log_filename(), "a") as fout: + # data = { + # "tstamp": round(finish_tstamp, 4), + # "type": "chat", + # "model": model_name0, + # "gen_params": {}, + # "start": round(start_tstamp, 4), + # "finish": round(finish_tstamp, 4), + # "state": state0.dict(), + # "ip": get_ip(request), + # } + # fout.write(json.dumps(data) + "\n") + # append_json_item_on_log_server(data, get_conv_log_filename()) + # data = { + # "tstamp": round(finish_tstamp, 4), + # "type": "chat", + # "model": model_name1, + # "gen_params": {}, + # "start": round(start_tstamp, 4), + # "finish": round(finish_tstamp, 4), + # "state": state1.dict(), + # "ip": get_ip(request), + # } + # fout.write(json.dumps(data) + "\n") + # append_json_item_on_log_server(data, get_conv_log_filename()) + # data = { + # "tstamp": round(finish_tstamp, 4), + # "type": "chat", + # "model": model_name2, + # "gen_params": {}, + # "start": round(start_tstamp, 4), + # "finish": round(finish_tstamp, 4), + # "state": state2.dict(), + # "ip": get_ip(request), + # } + # fout.write(json.dumps(data) + "\n") + # append_json_item_on_log_server(data, get_conv_log_filename()) + # data = { + # "tstamp": round(finish_tstamp, 4), + # "type": "chat", + # "model": model_name3, + # "gen_params": {}, + # "start": round(start_tstamp, 4), + # "finish": round(finish_tstamp, 4), + # "state": state3.dict(), + # "ip": get_ip(request), + # } + # fout.write(json.dumps(data) + "\n") + # append_json_item_on_log_server(data, get_conv_log_filename()) + + # for i, state in enumerate([state0, state1, state2, state3]): + # output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' + # os.makedirs(os.path.dirname(output_file), exist_ok=True) + # with open(output_file, 'w') as f: + # save_any_image(state.output, f) + # save_image_file_on_log_server(output_file) + +def generate_b2i_annoy(gen_func, state0, state1, state2, state3, text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3, request: gr.Request): + if not text.strip(): + return (gr.update(visible=False),) * 16 + if state0 is None: + state0 = ImageStateIG(model_name0) + if state1 is None: + state1 = ImageStateIG(model_name1) + if state2 is None: + state2 = ImageStateIG(model_name2) + if state3 is None: + state3 = ImageStateIG(model_name3) + + + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = "" + model_name1 = "" + model_name2 = "" + model_name3 = "" + + generated_image0, generated_image1, generated_image2, generated_image3, model_name0, model_name1, model_name2, model_name3 \ + = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3) + state0.prompt = text + state1.prompt = text + state2.prompt = text + state3.prompt = text + + state0.output = generated_image0 + state1.output = generated_image1 + state2.output = generated_image2 + state3.output = generated_image3 + + state0.model_name = model_name0 + state1.model_name = model_name1 + state2.model_name = model_name2 + state3.model_name = model_name3 + + yield state0, state1, state2, state3, generated_image0, generated_image1, generated_image2, generated_image3, \ + generated_image0, generated_image1, generated_image2, generated_image3, \ + gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False), \ + gr.Markdown(f"### Model C: {model_name2}", visible=False), gr.Markdown(f"### Model D: {model_name3}", visible=False) + + +def generate_igm_cache_annoy(gen_func, state0, state1, state2, state3, model_name0, model_name1, model_name2, model_name3, request: gr.Request): + if state0 is None: + state0 = ImageStateIG(model_name0) + if state1 is None: + state1 = ImageStateIG(model_name1) + if state2 is None: + state2 = ImageStateIG(model_name2) + if state3 is None: + state3 = ImageStateIG(model_name3) + + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = "" + model_name1 = "" + model_name2 = "" + model_name3 = "" + + generated_image0, generated_image1, generated_image2, generated_image3, model_name0, model_name1, model_name2, model_name3, text \ + = gen_func(model_name0, model_name1, model_name2, model_name3) + state0.prompt = text + state1.prompt = text + state2.prompt = text + state3.prompt = text + + state0.output = generated_image0 + state1.output = generated_image1 + state2.output = generated_image2 + state3.output = generated_image3 + + state0.model_name = model_name0 + state1.model_name = model_name1 + state2.model_name = model_name2 + state3.model_name = model_name3 + + yield state0, state1, state2, state3, generated_image0, generated_image1, generated_image2, generated_image3, \ + generated_image0, generated_image1, generated_image2, generated_image3, \ + gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False), \ + gr.Markdown(f"### Model C: {model_name2}", visible=False), gr.Markdown(f"### Model D: {model_name3}", visible=False), text + +def generate_vg_annoy(gen_func, state0, state1, state2, state3, model_name0, model_name1, model_name2, model_name3, request: gr.Request): + + if state0 is None: + state0 = ImageStateIG(model_name0) + if state1 is None: + state1 = ImageStateIG(model_name1) + if state2 is None: + state2 = ImageStateIG(model_name2) + if state3 is None: + state3 = ImageStateIG(model_name3) + + + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = "" + model_name1 = "" + model_name2 = "" + model_name3 = "" + + generated_video0, generated_video1, generated_video2, generated_video3, model_name0, model_name1, model_name2, model_name3, text, prompt_path \ + = gen_func(model_name0, model_name1, model_name2, model_name3) + state0.prompt = text + state1.prompt = text + state2.prompt = text + state3.prompt = text + + state0.output = generated_video0 + state1.output = generated_video1 + state2.output = generated_video2 + state3.output = generated_video3 + + state0.model_name = model_name0 + state1.model_name = model_name1 + state2.model_name = model_name2 + state3.model_name = model_name3 + + yield state0, state1, state2, state3, generated_video0, generated_video1, generated_video2, generated_video3, \ + generated_video0, generated_video1, generated_video2, generated_video3, \ + gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False), \ + gr.Markdown(f"### Model C: {model_name2}", visible=False), gr.Markdown(f"### Model D: {model_name3}", visible=False), \ + text, prompt_path + + +def generate_igm_annoy_museum(gen_func, state0, state1, state2, state3, model_name0, model_name1, model_name2, model_name3, request: gr.Request): + if state0 is None: + state0 = ImageStateIG(model_name0) + if state1 is None: + state1 = ImageStateIG(model_name1) + if state2 is None: + state2 = ImageStateIG(model_name2) + if state3 is None: + state3 = ImageStateIG(model_name3) + + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = re.sub(r"### Model A: ", "", model_name0) + model_name1 = re.sub(r"### Model B: ", "", model_name1) + model_name2 = re.sub(r"### Model C: ", "", model_name2) + model_name3 = re.sub(r"### Model D: ", "", model_name3) + + generated_image0, generated_image1, generated_image2, generated_image3, model_name0, model_name1, model_name2, model_name3, text \ + = gen_func(model_name0, model_name1, model_name2, model_name3) + + state0.prompt = text + state1.prompt = text + state2.prompt = text + state3.prompt = text + + state0.output = generated_image0 + state1.output = generated_image1 + state2.output = generated_image2 + state3.output = generated_image3 + + state0.model_name = model_name0 + state1.model_name = model_name1 + state2.model_name = model_name2 + state3.model_name = model_name3 + + + yield state0, state1, state2, state3, generated_image0, generated_image1, generated_image2, generated_image3, \ + generated_image0, generated_image1, generated_image2, generated_image3, text, \ + gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}"), \ + gr.Markdown(f"### Model C: {model_name2}"), gr.Markdown(f"### Model D: {model_name3}") + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + # with open(get_conv_log_filename(), "a") as fout: + # data = { + # "tstamp": round(finish_tstamp, 4), + # "type": "chat", + # "model": model_name0, + # "gen_params": {}, + # "start": round(start_tstamp, 4), + # "finish": round(finish_tstamp, 4), + # "state": state0.dict(), + # "ip": get_ip(request), + # } + # fout.write(json.dumps(data) + "\n") + # append_json_item_on_log_server(data, get_conv_log_filename()) + # data = { + # "tstamp": round(finish_tstamp, 4), + # "type": "chat", + # "model": model_name1, + # "gen_params": {}, + # "start": round(start_tstamp, 4), + # "finish": round(finish_tstamp, 4), + # "state": state1.dict(), + # "ip": get_ip(request), + # } + # fout.write(json.dumps(data) + "\n") + # append_json_item_on_log_server(data, get_conv_log_filename()) + # data = { + # "tstamp": round(finish_tstamp, 4), + # "type": "chat", + # "model": model_name2, + # "gen_params": {}, + # "start": round(start_tstamp, 4), + # "finish": round(finish_tstamp, 4), + # "state": state2.dict(), + # "ip": get_ip(request), + # } + # fout.write(json.dumps(data) + "\n") + # append_json_item_on_log_server(data, get_conv_log_filename()) + # data = { + # "tstamp": round(finish_tstamp, 4), + # "type": "chat", + # "model": model_name3, + # "gen_params": {}, + # "start": round(start_tstamp, 4), + # "finish": round(finish_tstamp, 4), + # "state": state3.dict(), + # "ip": get_ip(request), + # } + # fout.write(json.dumps(data) + "\n") + # append_json_item_on_log_server(data, get_conv_log_filename()) + + + # for i, state in enumerate([state0, state1, state2, state3]): + # output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' + # os.makedirs(os.path.dirname(output_file), exist_ok=True) + # with open(output_file, 'w') as f: + # save_any_image(state.output, f) + # save_image_file_on_log_server(output_file) + +def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request): + if not source_text: + raise gr.Warning("Source prompt cannot be empty.") + if not target_text: + raise gr.Warning("Target prompt cannot be empty.") + if not instruct_text: + raise gr.Warning("Instruction prompt cannot be empty.") + if not source_image: + raise gr.Warning("Source image cannot be empty.") + if not model_name: + raise gr.Warning("Model name cannot be empty.") + if state is None: + state = ImageStateIE(model_name) + ip = get_ip(request) + ig_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name) + state.source_prompt = source_text + state.target_prompt = target_text + state.instruct_prompt = instruct_text + state.source_image = source_image + state.output = generated_image + state.model_name = model_name + + yield state, generated_image + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' + os.makedirs(os.path.dirname(src_img_file), exist_ok=True) + with open(src_img_file, 'w') as f: + save_any_image(state.source_image, f) + output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(src_img_file) + save_image_file_on_log_server(output_file) + +def generate_ie_museum(gen_func, state, model_name, request: gr.Request): + if not model_name: + raise gr.Warning("Model name cannot be empty.") + if state is None: + state = ImageStateIE(model_name) + ip = get_ip(request) + ig_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + source_image, generated_image, source_text, target_text, instruct_text = gen_func(model_name) + state.source_prompt = source_text + state.target_prompt = target_text + state.instruct_prompt = instruct_text + state.source_image = source_image + state.output = generated_image + state.model_name = model_name + + yield state, generated_image, source_image, source_text, target_text, instruct_text + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' + os.makedirs(os.path.dirname(src_img_file), exist_ok=True) + with open(src_img_file, 'w') as f: + save_any_image(state.source_image, f) + output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(src_img_file) + save_image_file_on_log_server(output_file) + + +def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request): + if not source_text: + raise gr.Warning("Source prompt cannot be empty.") + if not target_text: + raise gr.Warning("Target prompt cannot be empty.") + if not instruct_text: + raise gr.Warning("Instruction prompt cannot be empty.") + if not source_image: + raise gr.Warning("Source image cannot be empty.") + if not model_name0: + raise gr.Warning("Model name A cannot be empty.") + if not model_name1: + raise gr.Warning("Model name B cannot be empty.") + if state0 is None: + state0 = ImageStateIE(model_name0) + if state1 is None: + state1 = ImageStateIE(model_name1) + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = re.sub(r"### Model A: ", "", model_name0) + model_name1 = re.sub(r"### Model B: ", "", model_name1) + generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1) + state0.source_prompt = source_text + state0.target_prompt = target_text + state0.instruct_prompt = instruct_text + state0.source_image = source_image + state0.output = generated_image0 + state0.model_name = model_name0 + state1.source_prompt = source_text + state1.target_prompt = target_text + state1.instruct_prompt = instruct_text + state1.source_image = source_image + state1.output = generated_image1 + state1.model_name = model_name1 + + yield state0, state1, generated_image0, generated_image1 + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name0, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state0.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name1, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state1.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + for i, state in enumerate([state0, state1]): + src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' + os.makedirs(os.path.dirname(src_img_file), exist_ok=True) + with open(src_img_file, 'w') as f: + save_any_image(state.source_image, f) + output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(src_img_file) + save_image_file_on_log_server(output_file) + +def generate_iem_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): + if not model_name0: + raise gr.Warning("Model name A cannot be empty.") + if not model_name1: + raise gr.Warning("Model name B cannot be empty.") + if state0 is None: + state0 = ImageStateIE(model_name0) + if state1 is None: + state1 = ImageStateIE(model_name1) + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = re.sub(r"### Model A: ", "", model_name0) + model_name1 = re.sub(r"### Model B: ", "", model_name1) + source_image, generated_image0, generated_image1, source_text, target_text, instruct_text = gen_func(model_name0, model_name1) + state0.source_prompt = source_text + state0.target_prompt = target_text + state0.instruct_prompt = instruct_text + state0.source_image = source_image + state0.output = generated_image0 + state0.model_name = model_name0 + state1.source_prompt = source_text + state1.target_prompt = target_text + state1.instruct_prompt = instruct_text + state1.source_image = source_image + state1.output = generated_image1 + state1.model_name = model_name1 + + yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name0, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state0.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name1, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state1.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + for i, state in enumerate([state0, state1]): + src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' + os.makedirs(os.path.dirname(src_img_file), exist_ok=True) + with open(src_img_file, 'w') as f: + save_any_image(state.source_image, f) + output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(src_img_file) + save_image_file_on_log_server(output_file) + + +def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request): + if not source_text: + raise gr.Warning("Source prompt cannot be empty.") + if not target_text: + raise gr.Warning("Target prompt cannot be empty.") + if not instruct_text: + raise gr.Warning("Instruction prompt cannot be empty.") + if not source_image: + raise gr.Warning("Source image cannot be empty.") + if state0 is None: + state0 = ImageStateIE(model_name0) + if state1 is None: + state1 = ImageStateIE(model_name1) + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = "" + model_name1 = "" + generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1) + state0.source_prompt = source_text + state0.target_prompt = target_text + state0.instruct_prompt = instruct_text + state0.source_image = source_image + state0.output = generated_image0 + state0.model_name = model_name0 + state1.source_prompt = source_text + state1.target_prompt = target_text + state1.instruct_prompt = instruct_text + state1.source_image = source_image + state1.output = generated_image1 + state1.model_name = model_name1 + + yield state0, state1, generated_image0, generated_image1, \ + gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False) + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name0, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state0.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name1, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state1.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + for i, state in enumerate([state0, state1]): + src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' + os.makedirs(os.path.dirname(src_img_file), exist_ok=True) + with open(src_img_file, 'w') as f: + save_any_image(state.source_image, f) + output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(src_img_file) + save_image_file_on_log_server(output_file) + +def generate_iem_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): + if state0 is None: + state0 = ImageStateIE(model_name0) + if state1 is None: + state1 = ImageStateIE(model_name1) + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = "" + model_name1 = "" + source_image, generated_image0, generated_image1, source_text, target_text, instruct_text, model_name0, model_name1 = gen_func(model_name0, model_name1) + state0.source_prompt = source_text + state0.target_prompt = target_text + state0.instruct_prompt = instruct_text + state0.source_image = source_image + state0.output = generated_image0 + state0.model_name = model_name0 + state1.source_prompt = source_text + state1.target_prompt = target_text + state1.instruct_prompt = instruct_text + state1.source_image = source_image + state1.output = generated_image1 + state1.model_name = model_name1 + + yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text, \ + gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False) + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name0, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state0.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name1, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state1.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + for i, state in enumerate([state0, state1]): + src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' + os.makedirs(os.path.dirname(src_img_file), exist_ok=True) + with open(src_img_file, 'w') as f: + save_any_image(state.source_image, f) + output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' + with open(output_file, 'w') as f: + save_any_image(state.output, f) + save_image_file_on_log_server(src_img_file) + save_image_file_on_log_server(output_file) + +def generate_vg(gen_func, state, text, model_name, request: gr.Request): + if not text: + raise gr.Warning("Prompt cannot be empty.") + if not model_name: + raise gr.Warning("Model name cannot be empty.") + if state is None: + state = VideoStateVG(model_name) + ip = get_ip(request) + vg_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + generated_video = gen_func(text, model_name) + state.prompt = text + state.output = generated_video + state.model_name = model_name + + # yield state, generated_video + + finish_tstamp = time.time() + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + if model_name.startswith('fal'): + r = requests.get(state.output) + with open(output_file, 'wb') as outfile: + outfile.write(r.content) + else: + print("======== video shape: ========") + print(state.output.shape) + imageio.mimwrite(output_file, state.output, fps=8, quality=9) + + save_video_file_on_log_server(output_file) + yield state, output_file + +def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request): + if not text: + raise gr.Warning("Prompt cannot be empty.") + if not model_name0: + raise gr.Warning("Model name A cannot be empty.") + if not model_name1: + raise gr.Warning("Model name B cannot be empty.") + if state0 is None: + state0 = VideoStateVG(model_name0) + if state1 is None: + state1 = VideoStateVG(model_name1) + ip = get_ip(request) + igm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + # Remove ### Model (A|B): from model name + model_name0 = re.sub(r"### Model A: ", "", model_name0) + model_name1 = re.sub(r"### Model B: ", "", model_name1) + generated_video0, generated_video1 = gen_func(text, model_name0, model_name1) + state0.prompt = text + state1.prompt = text + state0.output = generated_video0 + state1.output = generated_video1 + state0.model_name = model_name0 + state1.model_name = model_name1 + + # yield state0, state1, generated_video0, generated_video1 + print("====== model name =========") + print(state0.model_name) + print(state1.model_name) + + + finish_tstamp = time.time() + + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name0, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state0.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name1, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state1.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + for i, state in enumerate([state0, state1]): + output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + print(state.model_name) + + if state.model_name.startswith('fal'): + r = requests.get(state.output) + with open(output_file, 'wb') as outfile: + outfile.write(r.content) + else: + print("======== video shape: ========") + print(state.output) + print(state.output.shape) + imageio.mimwrite(output_file, state.output, fps=8, quality=9) + save_video_file_on_log_server(output_file) + yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4' + + +def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request): + if not text: + raise gr.Warning("Prompt cannot be empty.") + if state0 is None: + state0 = VideoStateVG(model_name0) + if state1 is None: + state1 = VideoStateVG(model_name1) + ip = get_ip(request) + vgm_logger.info(f"generate. ip: {ip}") + start_tstamp = time.time() + model_name0 = "" + model_name1 = "" + generated_video0, generated_video1, model_name0, model_name1 = gen_func(text, model_name0, model_name1) + state0.prompt = text + state1.prompt = text + state0.output = generated_video0 + state1.output = generated_video1 + state0.model_name = model_name0 + state1.model_name = model_name1 + + # yield state0, state1, generated_video0, generated_video1, \ + # gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}") + + finish_tstamp = time.time() + # logger.info(f"===output===: {output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name0, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state0.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name1, + "gen_params": {}, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state1.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + append_json_item_on_log_server(data, get_conv_log_filename()) + + for i, state in enumerate([state0, state1]): + output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' + os.makedirs(os.path.dirname(output_file), exist_ok=True) + if state.model_name.startswith('fal'): + r = requests.get(state.output) + with open(output_file, 'wb') as outfile: + outfile.write(r.content) + else: + print("======== video shape: ========") + print(state.output.shape) + imageio.mimwrite(output_file, state.output, fps=8, quality=9) + save_video_file_on_log_server(output_file) + + yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', \ + gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False) \ No newline at end of file diff --git a/tmp.py b/tmp.py new file mode 100644 index 0000000000000000000000000000000000000000..6d32881401c1953e53015f05c8298f75d182bf96 --- /dev/null +++ b/tmp.py @@ -0,0 +1,5 @@ +def generate_image_b2i(self, prompt, grounding_instruction, bbox, model_name): + if model_name == "local_MIGC_b2i": + from model_bbox.MIGC.inference_single_image import inference_single_image + result = inference_single_image(prompt, grounding_instruction, bbox) + return result \ No newline at end of file