# Hugging Face Spaces status at time of capture: Runtime error
from PIL import Image | |
# import torch | |
import os | |
import gradio as gr | |
# Below is the code refactored into a Python class for better modularity and reusability. | |
import torch | |
from transformers import TextStreamer | |
class FloorPlanAnalyzer:
    """Describe floor plan images with an Unsloth vision-language model."""

    def __init__(self, model_path, load_in_4bit=True, gradient_checkpointing="unsloth", device=None):
        """
        Initializes the FloorPlanAnalyzer with the specified model and configuration.

        Args:
            model_path (str): Model directory or hub id passed to
                ``FastVisionModel.from_pretrained``.
            load_in_4bit (bool): Forwarded as ``load_in_4bit``.
            gradient_checkpointing: Forwarded as ``use_gradient_checkpointing``.
            device (str | torch.device | None): Device the prepared inputs are
                moved to. ``None`` (default) uses the device reported by the
                loaded model, so inputs and weights always agree.
        """
        from unsloth import FastVisionModel  # Assuming unsloth package is installed
        self.model, self.tokenizer = FastVisionModel.from_pretrained(
            model_path,
            load_in_4bit=load_in_4bit,
            use_gradient_checkpointing=gradient_checkpointing,
        )
        FastVisionModel.for_inference(self.model)
        # A hard-coded "cpu" mismatches weights placed on the GPU by the loader;
        # fall back to wherever the model actually lives.
        self.device = device if device is not None else self.model.device

    def prepare_input(self, image_path, instruction):
        """
        Prepares the input for the model by loading the image and applying the chat template.

        Args:
            image_path (str | os.PathLike | PIL.Image.Image): Path to the floor
                plan image, or an already-loaded PIL image (which is what the
                Gradio ``gr.Image(type="pil")`` component supplies).
            instruction (str): Instruction text to guide the analysis.

        Returns:
            The processor output (a mapping of tensors such as ``input_ids``),
            moved to ``self.device``.

        Raises:
            ValueError: If ``image_path`` is ``None`` (e.g. nothing was uploaded).
        """
        if image_path is None:
            raise ValueError("An image (file path or PIL.Image) is required.")
        if isinstance(image_path, (str, os.PathLike)):
            # Close the file handle once the pixels are converted/copied.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        else:
            # Already an image object: only normalise the colour mode.
            image = image_path.convert("RGB")
        # Create message template: one image placeholder followed by the prompt text.
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]
        # Generate input text
        input_text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        # Tokenize and prepare inputs; the chat template already inserted special tokens.
        inputs = self.tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)
        return inputs

    def analyze(self, image_path, instruction, max_new_tokens=512, temperature=1.5, min_p=0.1):
        """
        Analyzes the floor plan based on the provided instruction.

        Tokens are also streamed to stdout via ``TextStreamer`` while generating.

        Args:
            image_path (str | os.PathLike | PIL.Image.Image): Floor plan image
                path or PIL image.
            instruction (str): Instruction guiding the analysis.
            max_new_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature for generation.
            min_p (float): Min-p sampling threshold: tokens whose probability is
                below ``min_p`` times the most likely token's probability are
                discarded.

        Returns:
            str: The generated text (the prompt is not included).
        """
        inputs = self.prepare_input(image_path, instruction)
        text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        output = self.model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            temperature=temperature,
            min_p=min_p,
        )
        # generate() returns prompt + completion ids; decode only the new tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        return self.tokenizer.batch_decode(
            output[:, prompt_length:], skip_special_tokens=True
        )[0]
# Instantiate the FloorPlanAnalyzer once at import time.
model_path = "./model/"
analyzer = FloorPlanAnalyzer(model_path=model_path)

# Sample images offered as Gradio examples for user convenience.
sample_images = [
    "./sample/10_2.jpg",
    "./sample/10_10.jpg",
    "./sample/0_10.jpg",
    "./sample/2_12.jpg",
]

# Save grey placeholder images for any missing sample (you should replace these
# with actual floor plans) so the examples never point at non-existent files.
for i, sample in enumerate(sample_images):
    if not os.path.exists(sample):
        # Create the directory the sample actually lives in, otherwise
        # img.save() raises FileNotFoundError on a fresh checkout.
        os.makedirs(os.path.dirname(sample) or ".", exist_ok=True)
        img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
        img.save(sample)
# Gradio prediction function
def predict_image(image, instruction):
    """
    Processes the uploaded image and instruction through the global ``analyzer``.

    Args:
        image (PIL.Image.Image | None): The uploaded floor plan image; Gradio
            passes ``None`` when the user submits without uploading one.
        instruction (str): The user-provided instruction.

    Returns:
        The analysis produced by ``analyzer.analyze`` for these inputs.

    Raises:
        gr.Error: If no image was uploaded, so the UI shows a readable message
            instead of an exception from deep inside the model code.
    """
    if image is None:
        raise gr.Error("Please upload a floor plan image before submitting.")
    return analyzer.analyze(image, instruction)
# Default prompt shown in the instruction box and reused for every example row.
DEFAULT_INSTRUCTION = """You are an expert in architecture and interior design. Analyze the floor plan image and describe accurately the key features, room count, layout, and any other important details you observe."""

gr_interface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type="pil", label="Upload Floor Plan Image"),
        gr.Textbox(
            label="Instruction Text",
            value=DEFAULT_INSTRUCTION,
        ),
    ],
    outputs=gr.Textbox(label="Analysis Result"),
    title="Floor Plan Analyzer",
    description="Upload a floor plan image and provide instructions to analyze it. Get detailed insights into the layout and design.",
    # With more than one input component Gradio needs one row per example that
    # holds a value for each input (image, instruction); a flat list of paths
    # is rejected with a ValueError when the Interface is constructed.
    examples=[[path, DEFAULT_INSTRUCTION] for path in sample_images],
)
# # Gradio UI setup with examples | |
# gr_interface = gr.Interface( | |
# fn=predict_image, | |
# inputs=gr.Image(type="pil"), # Updated to gr.Image for image input | |
# outputs=[gr.Image(type="pil"), gr.Textbox()], # Updated to gr.Image and gr.Textbox | |
# title="House CAD Design Object Detection", | |
# description="Upload a CAD design image of a house to detect objects with bounding boxes and probabilities.", | |
# examples=sample_images # Add the examples here | |
# ) | |
def main():
    """Start the Gradio web server for the floor plan analyzer UI."""
    gr_interface.launch()


# Only serve the UI when executed as a script, not when imported.
if __name__ == "__main__":
    main()