"""Gradio app that describes floor-plan images with a fine-tuned vision-language model.

The model is loaded from ``./model/`` through Unsloth's ``FastVisionModel`` at
import time. The Gradio interface takes an image plus a text instruction and
returns the model's generated description.
"""

import os

import gradio as gr
import torch  # NOTE(review): not referenced directly in this file
from PIL import Image
from transformers import TextStreamer


class FloorPlanAnalyzer:
    """Run a vision-language model over floor-plan images."""

    def __init__(
        self,
        model_path,
        load_in_4bit=True,
        gradient_checkpointing="unsloth",
        device=None,
    ):
        """Load the model and processor, then switch the model to inference mode.

        Args:
            model_path (str): Directory (or hub id) passed to
                ``FastVisionModel.from_pretrained``.
            load_in_4bit (bool): Whether to load the weights in 4-bit precision.
            gradient_checkpointing (str | bool): Forwarded as
                ``use_gradient_checkpointing``.
            device (str | torch.device | None): Device that input tensors are
                moved to. ``None`` (the default) uses the device the model was
                loaded on, so inputs and weights always live on the same device.
        """
        from unsloth import FastVisionModel  # Assumes the unsloth package is installed.

        self.model, self.tokenizer = FastVisionModel.from_pretrained(
            model_path,
            load_in_4bit=load_in_4bit,
            use_gradient_checkpointing=gradient_checkpointing,
        )
        FastVisionModel.for_inference(self.model)
        self.device = device if device is not None else self.model.device

    @staticmethod
    def _load_rgb(image):
        """Return *image* as an RGB PIL image, opening it first if it is a path."""
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        # The context manager closes the file handle once the pixels are converted.
        with Image.open(image) as img:
            return img.convert("RGB")

    def prepare_input(self, image_path, instruction):
        """Build model inputs from an image and an instruction.

        Args:
            image_path (str | os.PathLike | PIL.Image.Image): Path to the floor
                plan image, or an already-loaded PIL image (which is what
                ``gr.Image(type="pil")`` supplies).
            instruction (str): Instruction text to guide the analysis.

        Returns:
            The processor's batched tensors (a mapping that is unpacked into
            ``generate``), moved to ``self.device``.
        """
        image = self._load_rgb(image_path)

        # A single user turn: the image placeholder followed by the instruction text.
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction},
            ]}
        ]
        input_text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)

        # The chat template already inserted the special tokens.
        return self.tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)

    def analyze(self, image_path, instruction, max_new_tokens=512, temperature=1.5, min_p=0.1):
        """Analyze a floor plan according to *instruction*.

        Args:
            image_path (str | os.PathLike | PIL.Image.Image): Floor plan image or its path.
            instruction (str): Instruction guiding the analysis.
            max_new_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature for generation.
            min_p (float): Minimum-probability threshold for sampling.

        Returns:
            str: The generated text only, with the prompt and special tokens removed.
        """
        inputs = self.prepare_input(image_path, instruction)

        # Also echoes tokens to stdout as they are produced.
        text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)

        output_ids = self.model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            temperature=temperature,
            min_p=min_p,
        )

        # generate() returns prompt + completion ids; keep only the newly generated part.
        prompt_length = inputs["input_ids"].shape[-1]
        return self.tokenizer.batch_decode(
            output_ids[:, prompt_length:], skip_special_tokens=True
        )[0].strip()


DEFAULT_INSTRUCTION = (
    "You are an expert in architecture and interior design. Analyze the floor plan "
    "image and describe accurately the key features, room count, layout, and any "
    "other important details you observe."
)

# Instantiate the FloorPlanAnalyzer (loads the model at import time).
model_path = "./model/"
analyzer = FloorPlanAnalyzer(model_path=model_path)

# Sample images offered as Gradio examples.
sample_images = [
    "./sample/10_2.jpg",
    "./sample/10_10.jpg",
    "./sample/0_10.jpg",
    "./sample/2_12.jpg",
]

# Write grey placeholder images for any missing samples (replace them with real floor plans).
for i, sample in enumerate(sample_images):
    if not os.path.exists(sample):
        # Create the directory the sample path actually lives in; otherwise save() fails.
        os.makedirs(os.path.dirname(sample) or ".", exist_ok=True)
        img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
        img.save(sample)


def predict_image(image, instruction):
    """Run the uploaded image and instruction through the FloorPlanAnalyzer.

    Args:
        image (PIL.Image.Image | None): The uploaded floor plan image; Gradio
            passes ``None`` when nothing was uploaded.
        instruction (str): The user-provided instruction; a blank value falls
            back to ``DEFAULT_INSTRUCTION``.

    Returns:
        str: The generated description, or a prompt to upload an image.
    """
    if image is None:
        return "Please upload a floor plan image."
    return analyzer.analyze(image, instruction or DEFAULT_INSTRUCTION)


gr_interface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type="pil", label="Upload Floor Plan Image"),
        gr.Textbox(label="Instruction Text", value=DEFAULT_INSTRUCTION),
    ],
    outputs=gr.Textbox(label="Analysis Result"),
    title="Floor Plan Analyzer",
    description=(
        "Upload a floor plan image and provide instructions to analyze it. "
        "Get detailed insights into the layout and design."
    ),
    # With several input components, each example row supplies a value for every input.
    examples=[[path, DEFAULT_INSTRUCTION] for path in sample_images],
)

# Launch the Gradio interface if run as main.
if __name__ == "__main__":
    gr_interface.launch()