"""InkSight Gradio demo.

At import time this script downloads and unpacks the supplementary inks,
loads the static page assets (SVGs and base64-encoded GIFs), makes sure the
videos are available in ``./cached_videos`` and then builds and launches a
``gr.Blocks`` app whose "Sample" button calls :func:`demo`.
"""

import gradio as gr
import os
import random
import datetime

# NOTE: import order is kept as in the original. The star import can rebind
# names, so it must stay after the stdlib imports above and before the
# ``pathlib``/``gdown`` imports below.
from utils import *
from pathlib import Path
import gdown

# When True, build the video cache with ``pregenerate_videos`` instead of
# downloading the pre-generated archive from Google Drive.
pre_generate = False
file_url = "https://storage.googleapis.com/derendering_model/derendering_supp.zip"
filename = "derendering_supp.zip"

# Cache videos to speed up demo
video_cache_dir = Path("./cached_videos")
video_cache_dir.mkdir(exist_ok=True)

download_file(file_url, filename)
unzip_file(filename)
print("Downloaded and unzipped the inks.")

diagram = get_svg_content("derendering_supp/derender_diagram.svg")
org = get_svg_content("org/cor.svg")
org_content = f"{org}"

# Word-level example GIFs under gifs/ and the captions shown for them; the two
# lists are paired positionally.
gif_filenames = [
    "christians.gif",
    "good.gif",
    "october.gif",
    "welcome.gif",
    "you.gif",
    "letter.gif",
]
# NOTE(review): "WELOME" may be a typo for "WELCOME", unless it transcribes
# the handwritten sample verbatim; confirm before changing the caption.
captions = [
    "CHRISTIANS",
    "Good",
    "October",
    "WELOME",
    "you",
    "letter",
]
gif_base64_strings = {
    caption: get_base64_encoded_gif(f"gifs/{name}")
    for caption, name in zip(captions, gif_filenames)
}

sketches = [
    "bird.gif",
    "cat.gif",
    "coffee.gif",
    "penguin.gif",
]
sketches_base64_strings = {
    name: get_base64_encoded_gif(f"sketches/{name}") for name in sketches
}

gdrive_zip_path = video_cache_dir / "gdrive_file.zip"
if not pre_generate:
    # Only download (and unzip) the archive once; a previous run leaves the
    # zip file in the cache directory.
    if not gdrive_zip_path.exists():
        print("Downloading pre-generated videos from Google Drive.")
        gdown.download(
            "https://drive.google.com/uc?id=1oT6zw1EbWg3lavBMXsL28piULGNmqJzA",
            str(gdrive_zip_path),
            quiet=False,
        )
        unzip_file(str(gdrive_zip_path))
    else:
        print("File already exists. Skipping download.")
else:
    pregenerate_videos(video_cache_dir=video_cache_dir)
    print("Videos cached.")


# Dropdown label -> prefix of the per-model inkml directories under
# ./derendering_supp. Insertion order is the order shown in the dropdown.
MODEL_DIR_PREFIXES = {
    "Small-i": "small-i",
    "Large-i": "large-i",
    "Small-p": "small-p",
}
DATASETS = ["IAM", "IMGUR5K", "HierText"]

# Inference modes, in the order their outputs are returned by ``demo`` (and
# wired to the d+t / r+d / vanilla components below).
QUERY_MODES = ("d+t", "r+d", "vanilla")
PLOT_TITLE = {"r+d": "Recognized: ", "d+t": "OCR Input: ", "vanilla": ""}

# Private RNG seeded from OS entropy. Unlike re-seeding the global ``random``
# module on every request, it neither changes nor depends on global random
# state that other code may seed.
_sample_rng = random.Random()


def _example_id(image_name):
    """Return *image_name* with a trailing ``.png`` extension removed.

    ``str.strip(".png")`` must not be used here: it strips any of the
    characters ``.``, ``p``, ``n``, ``g`` from both ends, so for example
    ``"going.png"`` would become ``"oi"``.
    """
    suffix = ".png"
    if image_name.endswith(suffix):
        return image_name[: -len(suffix)]
    return image_name


def demo(Dataset: str, Model: str) -> tuple:
    """Pick a random sample image from *Dataset* and return *Model*'s outputs.

    Args:
        Dataset: dataset name used in the ``./derendering_supp`` paths
            (the UI offers "IAM", "IMGUR5K" and "HierText").
        Model: model variant label; must be a key of ``MODEL_DIR_PREFIXES``.

    Returns:
        ``(image, d+t text, d+t video, r+d text, r+d video, vanilla text,
        vanilla video)``. Videos are paths into ``video_cache_dir``. A video
        that is not cached yet is rendered with ``plot_ink_to_video`` first.

    Raises:
        gr.Error: if *Model* is unknown or the sample directory is empty.
    """
    # An explicit error replaces the UnboundLocalError an unmatched if/elif
    # chain used to produce.
    try:
        model_prefix = MODEL_DIR_PREFIXES[Model]
    except KeyError:
        raise gr.Error(f"Unknown InkSight model variant: {Model!r}") from None
    inkml_path = f"./derendering_supp/{model_prefix}_{Dataset}_inkml"

    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(now, "Taking sample from dataset:", Dataset, "and model:", Model)

    path = f"./derendering_supp/{Dataset}/images_sample"
    samples = os.listdir(path)
    if not samples:
        raise gr.Error(f"No sample images found in {path}.")
    name = _sample_rng.choice(samples)

    img = load_and_pad_img_dir(os.path.join(path, name))
    example_id = _example_id(name)

    outputs = [img]
    for mode in QUERY_MODES:
        inkml_file = os.path.join(inkml_path, mode, example_id + ".inkml")
        text_field = parse_inkml_annotations(inkml_file)["textField"]
        ink = inkml_to_ink(inkml_file)

        video_filepath = video_cache_dir / f"{Model}_{Dataset}_{mode}_{example_id}.mp4"
        if not video_filepath.exists():
            plot_ink_to_video(ink, str(video_filepath), input_image=img)
            print("Cached video at:", video_filepath)

        outputs.append(f"{PLOT_TITLE[mode]}{text_field}")
        outputs.append("./" + str(video_filepath))
    return tuple(outputs)


# NOTE(review): the HTML literals below appear to have lost their markup;
# only text nodes remain, and the base64 GIF strings are never interpolated.
# Confirm against the original file and restore the tags (e.g. <img> elements
# using the base64 data) if so.
LINKS_HTML = """
Read the Paper View on GitHub Google Research Blog Info
"""

INTRO_MARKDOWN = """ 🚀 This demo highlights the capabilities of Small-i, Small-p, and Large-i across three public datasets (word-level, with 100 random samples each).
🔔 We've just released the InkSight-Small-p model on Hugging Face! Check it out [here](https://huggingface.co/Derendering/InkSight-Small-p).
🎲 Select a model variant and dataset (IAM, IMGUR5K, HierText), then hit 'Sample' to view a randomly selected input alongside its corresponding outputs for all three types of inference.
"""

SVG_HTML_TEMPLATE = """
{}

{}

{}

{}

{}

{}

"""


def _gif_gallery_html(gifs):
    """Return the HTML for the word-level gallery (caption -> base64 GIF)."""
    parts = ["\n"]
    for caption, _gif_b64 in gifs.items():  # _gif_b64 unused: see NOTE above
        parts.append(f"\n{caption}\n\n{caption}\n\n")
    parts.append("\n")
    return "".join(parts)


def _sketch_gallery_html(sketch_gifs):
    """Return the HTML for the sketch gallery (file name -> base64 GIF)."""
    parts = ["\n"]
    for _sketch_b64 in sketch_gifs.values():  # unused: see NOTE above
        parts.append("\n")
    parts.append("\n")
    return "".join(parts)


with gr.Blocks() as app:
    gr.HTML(org_content)
    gr.Markdown("# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write")
    gr.HTML(LINKS_HTML)
    gr.HTML(f"\n{diagram}\n")
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        dataset = gr.Dropdown(DATASETS, label="Dataset", value="IAM")
        model = gr.Dropdown(
            list(MODEL_DIR_PREFIXES),
            label="InkSight Model Variant",
            value="Small-i",
        )
    im = gr.Image(label="Input Image")

    with gr.Row():
        d_t_text = gr.Textbox(label="OCR recognition input to the model", interactive=False)
        r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
        vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
    with gr.Row():
        d_t_vid = gr.Video(label="Derender with Text (Click to stop/play)", autoplay=True)
        r_d_vid = gr.Video(label="Recognize and Derender (Click to stop/play)", autoplay=True)
        vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)
    with gr.Row():
        btn_sub = gr.Button("Sample")

    # Output order must match the tuple returned by ``demo``: the image, then
    # (text, video) per mode in QUERY_MODES order.
    btn_sub.click(
        fn=demo,
        inputs=[dataset, model],
        outputs=[
            im,
            d_t_text,
            d_t_vid,
            r_d_text,
            r_d_vid,
            vanilla_text,
            vanilla_vid,
        ],
    )

    gr.Markdown("## More Word-level Samples")
    gr.HTML(_gif_gallery_html(gif_base64_strings))

    # Sketches
    gr.Markdown("## Sketch Samples")
    gr.HTML(_sketch_gallery_html(sketches_base64_strings))

    gr.Markdown("## Scale Up to Full Page")
    svg1_content = get_svg_content("full_page/danke.svg")
    svg2_content = get_svg_content("full_page/multilingual_demo.svg")
    svg3_content = get_svg_content("full_page/unsplash_frame.svg")
    full_svg_display = SVG_HTML_TEMPLATE.format(
        svg1_content,
        'Writings on the beach. Credit',
        svg2_content,
        "Multilingual handwriting.",
        svg3_content,
        "Handwriting in a frame. Credit",
    )
    gr.HTML(full_svg_display)

app.launch()