import os
import gradio as gr
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
from huggingface_hub import login
import spaces
import json
import matplotlib.pyplot as plt
import io
import base64


def check_environment():
    """Ensure every environment variable this app needs is set.

    Raises:
        ValueError: if ``HF_TOKEN`` is not present in the environment.
    """
    required_vars = ["HF_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token"
        )


# Login to Hugging Face
check_environment()
login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

# Load model and processor once at import time so every request reuses them
# instead of reloading the weights inside the inference function.
base_model_path = (
    "taesiri/BugsBunny-LLama-3.2-11B-Vision-BaseCaptioner-Medium-FullModel"
)
# Optional LoRA adapter, currently disabled:
# lora_weights_path = "taesiri/BugsBunny-LLama-3.2-11B-Vision-Base-Medium-LoRA"
processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# model = PeftModel.from_pretrained(model, lora_weights_path)
model.tie_weights()


def describe_image_in_JSON(json_string):
    """Parse the model's JSON answer, unwrapping one level of string encoding.

    The answer may be the JSON document itself, or a JSON string literal whose
    content is the JSON document. Both forms are accepted.

    Args:
        json_string: Raw text produced by the model.

    Returns:
        The decoded Python value (normally a ``dict``) on success, or an
        ``"Error parsing JSON: ..."`` message string on failure.
    """
    try:
        data = json.loads(json_string)
        # A JSON string literal wrapping the real document: decode once more.
        # Without this check a directly-emitted object would be passed back
        # into json.loads and raise TypeError.
        if isinstance(data, str):
            data = json.loads(data)
        return data
    except json.JSONDecodeError as e:
        return f"Error parsing JSON: {str(e)}"


def create_color_palette_image(colors):
    """Render a list of ``#``-prefixed colour strings as a strip of swatches.

    Args:
        colors: Colour strings such as ``["#1a2b3c", "#ffffff"]``.

    Returns:
        A matplotlib ``Figure`` suitable for ``gr.Plot``, or ``None`` if
        ``colors`` is empty, not a list, contains an entry that is not a
        ``#``-prefixed string, or cannot be drawn.
    """
    if not colors or not isinstance(colors, list):
        return None
    # Only "#..." string entries are accepted; anything else disables the plot.
    if not all(isinstance(color, str) and color.startswith("#") for color in colors):
        return None

    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 2))
        # One unit-wide rectangle per colour, laid out left to right.
        for i, color in enumerate(colors):
            ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])
    except Exception as e:
        print(f"Error creating color palette: {e}")
        if fig is not None:
            plt.close(fig)
        return None

    # Detach the figure from pyplot's global registry so a new figure per
    # request does not accumulate; the Figure object itself stays renderable.
    plt.close(fig)
    return fig


def _error_outputs(message, detail=None, raw=""):
    """Build the full 10-value output tuple for a failed analysis.

    The order matches the ``outputs`` list of ``submit_btn.click``. Text
    components receive ``message``. The JSON and Plot components receive
    ``None``, because a plain message string is neither valid JSON nor a
    figure. The error box is made visible with ``detail`` (or ``message``).

    Args:
        message: Short text shown in the text result fields.
        detail: Text for the error box; defaults to ``message``.
        raw: Text for the raw-output box (e.g. the unparsable model answer).
    """
    return (
        message,  # description
        message,  # scene description
        None,  # characters (gr.JSON)
        None,  # objects (gr.JSON)
        None,  # texture details (gr.JSON)
        message,  # lighting details
        None,  # colour palette (gr.Plot)
        raw,  # raw JSON response
        gr.update(value=detail or message, visible=True),  # error box
        "Error",  # status
    )


@spaces.GPU
def inference(image):
    """Run the captioner on ``image`` and spread its JSON answer over the UI.

    Args:
        image: A ``PIL.Image.Image`` (the input component uses
            ``type="pil"``), or an array ``Image.fromarray`` accepts, or
            ``None`` when nothing was uploaded.

    Returns:
        A 10-tuple in the order of ``submit_btn.click``'s ``outputs``:
        description, scene description, characters, objects, texture details,
        lighting details, colour-palette figure, raw JSON text, error-box
        update, and status text.
    """
    if image is None:
        return _error_outputs("Please provide an image")

    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return _error_outputs("Invalid image format")

    # Prepare input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe the image in JSON"},
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    try:
        # Move inputs to the device the model was loaded on.
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)
    except Exception as e:
        print(f"Inference error: {e}")
        return _error_outputs("Error during inference")
    finally:
        # Release cached GPU memory whether or not generation succeeded.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # generate() returns prompt + continuation; decode only the continuation so
    # the answer does not have to be located by searching for "assistant\n".
    prompt_length = inputs["input_ids"].shape[-1]
    json_str = processor.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()
    print("DEBUG: Extracted JSON string:", json_str)
    if not json_str:
        return _error_outputs(
            "Error extracting JSON from response", "Failed to extract JSON"
        )

    parsed_json = describe_image_in_JSON(json_str)
    if not isinstance(parsed_json, dict):
        # Either the parser's error message or a JSON value that is not an
        # object; in both cases the per-field lookups below cannot work.
        detail = "Failed to parse JSON"
        if isinstance(parsed_json, str):
            detail += f": {parsed_json}"
        return _error_outputs("Error parsing response", detail, raw=json_str)

    # Create color palette visualization
    color_image = create_color_palette_image(parsed_json.get("color_palette", []))

    # Serialise list fields to JSON text for the gr.JSON components; dumping
    # also keeps non-list values (e.g. a bare string) valid JSON.
    character_list = json.dumps(parsed_json.get("character_list", []))
    object_list = json.dumps(parsed_json.get("object_list", []))
    texture_details = json.dumps(parsed_json.get("texture_details", []))

    return (
        parsed_json.get("description", "Not available"),
        parsed_json.get("scene_description", "Not available"),
        character_list,
        object_list,
        texture_details,
        parsed_json.get("lighting_details", "Not available"),
        color_image,
        json_str,
        gr.update(value="", visible=False),  # Error box (hidden on success)
        "Analysis complete",  # Status
    )


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-Medium Demo")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Tabs():
            with gr.Tab("Structured Results"):
                with gr.Column(scale=1):
                    description_output = gr.Textbox(
                        label="Description",
                        lines=4,
                    )
                    scene_output = gr.Textbox(
                        label="Scene Description",
                        lines=2,
                    )
                    characters_output = gr.JSON(
                        label="Characters",
                    )
                    objects_output = gr.JSON(
                        label="Objects",
                    )
                    textures_output = gr.JSON(
                        label="Texture Details",
                    )
                    lighting_output = gr.Textbox(
                        label="Lighting Details",
                        lines=2,
                    )
                    color_palette_output = gr.Plot(
                        label="Color Palette",
                    )

            with gr.Tab("Raw Output"):
                raw_output = gr.Textbox(
                    label="Raw JSON Response",
                    lines=25,
                    max_lines=30,
                )

    # Hidden until inference() reports an error through gr.update.
    error_box = gr.Textbox(label="Error Messages", visible=False)

    with gr.Row():
        status_text = gr.Textbox(label="Status", value="Ready", interactive=False)

    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            description_output,
            scene_output,
            characters_output,
            objects_output,
            textures_output,
            lighting_output,
            color_palette_output,
            raw_output,
            error_box,
            status_text,
        ],
        api_name="analyze",
    )

demo.launch(share=True)