import os
import shlex
import subprocess

import gradio as gr
from easygui import msgbox

from .common_gui import get_folder_path, add_pre_postfix
from library.custom_logging import setup_logging

# Set up logging
log = setup_logging()


def caption_images(
    train_data_dir,
    caption_extension,
    batch_size,
    general_threshold,
    character_threshold,
    replace_underscores,
    model,
    recursive,
    max_data_loader_n_workers,
    debug,
    undesired_tags,
    frequency_tags,
    prefix,
    postfix,
):
    """Run the WD14 tagger script on a folder, then add a prefix/postfix.

    Launches ``./finetune/tag_images_by_wd14_tagger.py`` through
    ``accelerate launch`` and afterwards calls ``add_pre_postfix`` on the
    same folder.

    Args:
        train_data_dir: Folder passed to the tagger as its positional
            argument. An empty value aborts with a message box.
        caption_extension: Value for ``--caption_extension``. An empty value
            aborts with a message box.
        batch_size: Value for ``--batch_size`` (truncated with ``int``).
        general_threshold: Value for ``--general_threshold``.
        character_threshold: Value for ``--character_threshold``.
        replace_underscores: When truthy, passes ``--remove_underscore``.
        model: Value for ``--model``.
        recursive: When truthy, passes ``--recursive``.
        max_data_loader_n_workers: Value for ``--max_data_loader_n_workers``
            (truncated with ``int``).
        debug: When truthy, passes ``--debug``.
        undesired_tags: Value for ``--undesired_tags``; omitted when empty.
        frequency_tags: When truthy, passes ``--frequency_tags``.
        prefix: Forwarded to ``add_pre_postfix``.
        postfix: Forwarded to ``add_pre_postfix``.

    Returns:
        None.
    """
    # Check for images_dir_input
    if not train_data_dir:
        msgbox('Image folder is missing...')
        return

    if not caption_extension:
        msgbox('Please provide an extension for the caption files.')
        return

    log.info(f'Captioning files in {train_data_dir}...')

    # Build an argument vector rather than a shell string: the values below
    # come straight from GUI text fields, so they must never be interpreted
    # by a shell (quotes, `$`, backticks, `;` ... in a path or tag list).
    run_cmd = [
        'accelerate',
        'launch',
        './finetune/tag_images_by_wd14_tagger.py',
        f'--batch_size={int(batch_size)}',
        f'--general_threshold={general_threshold}',
        f'--character_threshold={character_threshold}',
        f'--caption_extension={caption_extension}',
        f'--model={model}',
        f'--max_data_loader_n_workers={int(max_data_loader_n_workers)}',
    ]

    optional_flags = (
        (recursive, '--recursive'),
        (debug, '--debug'),
        (replace_underscores, '--remove_underscore'),
        (frequency_tags, '--frequency_tags'),
    )
    run_cmd.extend(flag for enabled, flag in optional_flags if enabled)

    if undesired_tags:
        run_cmd.append(f'--undesired_tags={undesired_tags}')

    run_cmd.append(train_data_dir)

    # Quoted only for display; the process itself receives the raw list.
    log.info(' '.join(shlex.quote(arg) for arg in run_cmd))

    # Run the command the same way on every platform, without a shell.
    try:
        completed = subprocess.run(run_cmd)
    except OSError as e:
        # Typically `accelerate` is not on PATH; nothing ran, so there is
        # nothing to post-process.
        log.error(f'Unable to launch the WD14 tagger: {e}')
        return

    if completed.returncode != 0:
        # Keep going as before, but make the failure visible in the log.
        log.warning(
            f'WD14 tagger exited with return code {completed.returncode}'
        )

    # Add prefix and postfix
    add_pre_postfix(
        folder=train_data_dir,
        caption_file_ext=caption_extension,
        prefix=prefix,
        postfix=postfix,
    )

    log.info('...captioning done')


###
# Gradio UI
###


def gradio_wd14_caption_gui_tab(headless=False):
    """Build the 'WD14 Captioning' tab and wire it to ``caption_images``.

    Must be called inside an active Gradio layout context, since it opens a
    ``gr.Tab``.

    Args:
        headless: When True, the folder-picker button is created hidden.
    """
    with gr.Tab('WD14 Captioning'):
        gr.Markdown(
            'This utility will use WD14 to caption files for each images in a folder.'
        )

        # Input Settings
        with gr.Row():
            train_data_dir = gr.Textbox(
                label='Image folder to caption',
                placeholder='Directory containing the images to caption',
                interactive=True,
            )
            button_train_data_dir_input = gr.Button(
                '📂', elem_id='open_folder_small', visible=(not headless)
            )
            button_train_data_dir_input.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )

            caption_extension = gr.Textbox(
                label='Caption file extension',
                placeholder='Extention for caption file. eg: .caption, .txt',
                value='.txt',
                interactive=True,
            )

        # NOTE(review): nesting was reconstructed from a whitespace-collapsed
        # source; confirm this textbox belongs outside the row above.
        undesired_tags = gr.Textbox(
            label='Undesired tags',
            placeholder='(Optional) Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.',
            interactive=True,
        )

        with gr.Row():
            prefix = gr.Textbox(
                label='Prefix to add to WD14 caption',
                placeholder='(Optional)',
                interactive=True,
            )

            postfix = gr.Textbox(
                label='Postfix to add to WD14 caption',
                placeholder='(Optional)',
                interactive=True,
            )

        with gr.Row():
            replace_underscores = gr.Checkbox(
                label='Replace underscores in filenames with spaces',
                value=True,
                interactive=True,
            )
            recursive = gr.Checkbox(
                label='Recursive',
                value=False,
                info='Tag subfolders images as well',
            )
            debug = gr.Checkbox(
                label='Verbose logging',
                value=True,
                info='Debug while tagging, it will print your image file with general tags and character tags.',
            )
            frequency_tags = gr.Checkbox(
                label='Show tags frequency',
                value=True,
                info='Show frequency of tags for images.',
            )

        # Model Settings
        with gr.Row():
            model = gr.Dropdown(
                label='Model',
                choices=[
                    'SmilingWolf/wd-v1-4-convnext-tagger-v2',
                    'SmilingWolf/wd-v1-4-convnextv2-tagger-v2',
                    'SmilingWolf/wd-v1-4-vit-tagger-v2',
                    'SmilingWolf/wd-v1-4-swinv2-tagger-v2',
                ],
                value='SmilingWolf/wd-v1-4-convnextv2-tagger-v2',
            )

            general_threshold = gr.Slider(
                value=0.35,
                label='General threshold',
                info='Adjust `general_threshold` for pruning tags (less tags, less flexible)',
                minimum=0,
                maximum=1,
                step=0.05,
            )
            character_threshold = gr.Slider(
                value=0.35,
                label='Character threshold',
                info='useful if you want to train with character',
                minimum=0,
                maximum=1,
                step=0.05,
            )

        # Advanced Settings
        with gr.Row():
            batch_size = gr.Number(
                value=8, label='Batch size', interactive=True
            )

            max_data_loader_n_workers = gr.Number(
                value=2, label='Max dataloader workers', interactive=True
            )

        caption_button = gr.Button('Caption images')

        # The input order must match the positional parameters of
        # caption_images.
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_extension,
                batch_size,
                general_threshold,
                character_threshold,
                replace_underscores,
                model,
                recursive,
                max_data_loader_n_workers,
                debug,
                undesired_tags,
                frequency_tags,
                prefix,
                postfix,
            ],
            show_progress=False,
        )