#!/usr/bin/env python from __future__ import annotations import os import random from typing import Tuple, Optional import gradio as gr from huggingface_hub import HfApi from inf import InferencePipeline SAMPLE_MODEL_IDS = [ 'lora-library/B-LoRA-teddybear', 'lora-library/B-LoRA-bull', 'lora-library/B-LoRA-wolf_plushie', 'lora-library/B-LoRA-pen_sketch', 'lora-library/B-LoRA-cartoon_line', 'lora-library/B-LoRA-multi-dog2', ] css = """ body { font-size: 30px; } #title { text-align: center; } #title h1 { font-size: 250%; } .lora-title { text-align: center; border-radius: 10px; background: linear-gradient(90deg, #1CB5E0 0%, #000851 100%); } .gr-image { width: 512px; height: 512px; object-fit: contain; margin: auto; } .lora-column { display: flex; flex-direction: column; align-items: center; justify-content: center; border: none; background: none; } .gr-row { align-items: center; justify-content: center; margin-top: 5px; } """ def get_choices(hf_token): api = HfApi(token=hf_token) choices = [ info.modelId for info in api.list_models(author='lora-library') ] models_list = ['None'] + SAMPLE_MODEL_IDS + choices return models_list def get_image_from_card(card, model_id) -> Optional[str]: try: card_path = f"https://huggingface.co/{model_id}/resolve/main/" widget = card.data.get('widget') if widget is not None or len(widget) > 0: output = widget[0].get('output') if output is not None: url = output.get('url') if url is not None: return card_path + url return None except Exception: return None def demo_init(): try: choices = get_choices(app.hf_token) content_blora = random.choice(SAMPLE_MODEL_IDS) style_blora = random.choice(SAMPLE_MODEL_IDS) content_blora_prompt, content_blora_image = app.load_model_info(content_blora) style_blora_prompt, style_blora_image = app.load_model_info(style_blora) content_lora_model_id = gr.update(choices=choices, value=content_blora) content_prompt = gr.update(value=content_blora_prompt) content_image = gr.update(value=content_blora_image) style_lora_model_id = gr.update(choices=choices, value=style_blora) style_prompt = gr.update(value=style_blora_prompt) style_image = gr.update(value=style_blora_image) prompt = gr.update( value=f'{content_blora_prompt} in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style') return content_lora_model_id, content_prompt, content_image, style_lora_model_id, style_prompt, style_image, prompt except Exception as e: raise type(e)(f'failed to demo_init, due to: {e}') def toggle_column(is_checked): try: return 'None' if is_checked else random.choice(SAMPLE_MODEL_IDS) except Exception as e: raise type(e)(f'failed to toggle_column, due to: {e}') class InferenceUtil: def __init__(self, hf_token: str | None): self.hf_token = hf_token def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]: try: try: card = InferencePipeline.get_model_card(lora_model_id, self.hf_token) except Exception: return '', None instance_prompt = getattr(card.data, 'instance_prompt', '') image_url = get_image_from_card(card, lora_model_id) return instance_prompt, image_url except Exception as e: raise type(e)(f'failed to load_model_info, due to: {e}') def update_model_info(self, model_source: str): try: if model_source == 'None': return '', None else: model_info = self.load_model_info(model_source) new_prompt, new_image = model_info[0], model_info[1] return new_prompt, new_image except Exception as e: raise type(e)(f'failed to update_model_info, due to: {e}') hf_token = os.getenv('HF_TOKEN') pipe = InferencePipeline(hf_token) app = InferenceUtil(hf_token) with gr.Blocks(css=css) as demo: title = gr.HTML( '''
This is a demo for our paper: ''Implicit Style-Content Separation using B-LoRA''.
Project page and code is available here.