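"""Gradio demo for MAR (Masked Autoregressive Model) image generation.

Loads the jadechoghari/mar diffusers pipeline (with its custom pipeline code
from the Hugging Face Hub) and exposes class-conditional ImageNet generation
through a simple Blocks UI.
"""
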
import gradio as gr
from diffusers import DiffusionPipeline
import os
import torch
import shutil
import spaces


def find_cuda():
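    """Locate a local CUDA installation and return its root path, or None."""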
    # Check if CUDA_HOME or CUDA_PATH environment variables are set
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home and os.path.exists(cuda_home):
        return cuda_home

    # Search for the nvcc executable in the system's PATH
    nvcc_path = shutil.which('nvcc')
    if nvcc_path:
        # Remove the 'bin/nvcc' part to get the CUDA installation path
        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
        return cuda_path

    return None

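
# Report whether a local CUDA toolkit was found (informational only)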
cuda_path = find_cuda()
if cuda_path:
    print(f"CUDA installation found at: {cuda_path}")
else:
    print("CUDA installation not found")

# Check whether CUDA is available and pick the device accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the MAR pipeline, pulling its custom pipeline code from the Hub
pipeline = DiffusionPipeline.from_pretrained(
    "jadechoghari/mar",
    trust_remote_code=True,
    custom_pipeline="jadechoghari/mar",
)


# Function that generates images; @spaces.GPU requests a GPU when running on ZeroGPU Spaces
@spaces.GPU
def generate_image(seed, num_ar_steps, class_labels, cfg_scale, cfg_schedule):
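    """Run the MAR pipeline with the given seed, AR steps, class labels, and CFG settings."""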
    generated_image = pipeline(
        model_type="mar_huge",  # using mar_huge
        seed=seed,
        num_ar_steps=num_ar_steps,
        class_labels=[int(label.strip()) for label in class_labels.split(',')],
        cfg_scale=cfg_scale,
        cfg_schedule=cfg_schedule,
        output_dir="./images",
    )
    return generated_image

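
# Build the Gradio UI: description, parameter controls, output image, and examples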
with gr.Blocks() as demo:
    gr.Markdown("""
    # MAR Image Generation Demo 🚀

    Welcome to the demo for **MAR** (Masked Autoregressive Model), a novel approach to image generation that eliminates the need for vector quantization. MAR uses a diffusion process to generate images in a continuous-valued space, resulting in faster, more efficient, and higher-quality outputs.

    Adjust the parameters below to create custom images in real time.

    Make sure to provide valid **ImageNet class labels** to condition the generation. For a complete list of ImageNet classes, check out [this reference](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/).

    For more details, visit the [GitHub repository](https://github.com/LTH14/mar).
    """)

    seed = gr.Number(value=0, label="Seed")
    num_ar_steps = gr.Slider(minimum=1, maximum=256, value=64, label="Number of AR Steps")
    class_labels = gr.Textbox(value="207, 360, 388, 113, 355, 980, 323, 979", label="Class Labels (comma-separated ImageNet labels)")
    cfg_scale = gr.Slider(minimum=1, maximum=10, value=4, label="CFG Scale")
    cfg_schedule = gr.Dropdown(choices=["constant", "linear"], label="CFG Schedule", value="constant")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate Image")

    # Link the button to the generation function and display the output
    generate_button.click(
        generate_image,
        inputs=[seed, num_ar_steps, class_labels, cfg_scale, cfg_schedule],
        outputs=image_output,
    )

    # Add example inputs (caching can be enabled for quicker demos)
    gr.Examples(
        examples=[
            [0, 64, "207, 360, 388, 113, 355, 980, 323, 979", 4, "constant"],
        ],
        inputs=[seed, num_ar_steps, class_labels, cfg_scale, cfg_schedule],
        # fn=generate_image,
        # outputs=image_output,
        # cache_examples=True,
    )
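
# Launch the demo (Spaces serves this automatically; locally, run `python app.py`)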
demo.launch()