# Hugging Face Space file-page header (not part of the program):
# author: sabaridsnfuji · commit 40891b3 "Update app.py" (verified) · file size 5.3 kB
from PIL import Image
# import torch
import os
import gradio as gr
# Below is the code refactored into a Python class for better modularity and reusability.
import torch
from transformers import TextStreamer
class FloorPlanAnalyzer:
    """Answer free-text instructions about a single floor-plan image.

    Wraps an unsloth ``FastVisionModel`` checkpoint and the tokenizer/processor
    returned alongside it. The model is loaded once in ``__init__``; each call to
    :meth:`analyze` runs one generation and returns the decoded answer text.
    """

    def __init__(self, model_path, load_in_4bit=True, gradient_checkpointing="unsloth", device="cpu"):
        """
        Initializes the FloorPlanAnalyzer with the specified model and configuration.

        Args:
            model_path (str): Location passed to ``FastVisionModel.from_pretrained``.
            load_in_4bit (bool): Forwarded as ``load_in_4bit``.
            gradient_checkpointing: Forwarded as ``use_gradient_checkpointing``.
            device (str): Device the prepared model inputs are moved to.
        """
        # Imported here rather than at module top, so importing this module does
        # not itself require unsloth.
        from unsloth import FastVisionModel  # Assuming unsloth package is installed

        # NOTE(review): unsloth models are normally placed on a CUDA device; with
        # the default device="cpu" the inputs may not be on the same device as the
        # model weights. Confirm the deployment target and pass device= to match.
        self.device = device
        # self.tokenizer is called below with (image, text), so it is presumably a
        # multimodal processor rather than a plain text tokenizer.
        self.model, self.tokenizer = FastVisionModel.from_pretrained(
            model_path,
            load_in_4bit=load_in_4bit,
            use_gradient_checkpointing=gradient_checkpointing,
        )
        FastVisionModel.for_inference(self.model)

    @staticmethod
    def _load_image(image):
        """Return *image* as an RGB PIL image.

        Accepts an already-decoded image (anything with a ``convert`` method, such
        as the ``PIL.Image.Image`` that ``gr.Image(type="pil")`` produces), or
        anything ``PIL.Image.open`` accepts (a filesystem path or binary file object).

        Raises:
            ValueError: If *image* is None.
        """
        if image is None:
            raise ValueError("No image provided.")
        if hasattr(image, "convert"):
            return image.convert("RGB")
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it remains usable after the handle is closed.
        with Image.open(image) as opened:
            return opened.convert("RGB")

    def prepare_input(self, image_path, instruction):
        """
        Prepares the input for the model by loading the image and applying the chat template.

        Args:
            image_path: Path or file object of the floor plan image, or an
                already-loaded PIL image.
            instruction (str): Instruction text to guide the analysis.

        Returns:
            The processor output (a mapping of tensors, moved to ``self.device``)
            that is unpacked as keyword arguments into ``model.generate``.
        """
        image = self._load_image(image_path)
        # One user turn: an image placeholder followed by the instruction text.
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction},
            ]}
        ]
        input_text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        # The chat template already inserted any special tokens, so none are added here.
        inputs = self.tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)
        return inputs

    def _decode_generated(self, output, prompt_length):
        """Decode only the newly generated tokens of the first sequence in *output*.

        Args:
            output: 2-D token-id tensor returned by ``model.generate``.
            prompt_length (int): Number of prompt tokens at the start of each row.

        Returns:
            str: The decoded continuation, with special tokens removed.
        """
        # Assumes a decoder-only model (as skip_prompt=True on the streamer also
        # implies): each output row is prompt tokens followed by the answer.
        new_tokens = output[:, prompt_length:]
        return self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

    def analyze(self, image_path, instruction, max_new_tokens=512, temperature=1.5, min_p=0.1):
        """
        Analyzes the floor plan based on the provided instruction.

        The answer is also streamed to stdout via ``TextStreamer`` while it is generated.

        Args:
            image_path: Path or file object of the floor plan image, or an
                already-loaded PIL image.
            instruction (str): Instruction guiding the analysis.
            max_new_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature for generation.
            min_p (float): Minimum-probability cutoff for sampling, forwarded to ``generate``.

        Returns:
            str: The generated answer text (prompt excluded).
        """
        inputs = self.prepare_input(image_path, instruction)
        text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        output = self.model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            temperature=temperature,
            min_p=min_p,
        )
        # generate() returns token ids; decode them so callers (e.g. a Gradio
        # Textbox) receive text, as documented.
        return self._decode_generated(output, inputs["input_ids"].shape[1])
# One shared analyzer is built at import time; predict_image() reuses it for every request.
model_path = "./model/"  # location handed to FastVisionModel.from_pretrained
analyzer = FloorPlanAnalyzer(model_path)
# Sample images for Gradio examples
# Define sample images for user convenience
sample_images = [
"./sample/10_2.jpg",
"./sample/10_10.jpg",
"./sample/0_10.jpg",
"./sample/2_12.jpg"
]
# Ensure sample images directory exists
os.makedirs("samples", exist_ok=True)
# Save some dummy sample images if they don't exist (you should replace these with actual images)
for i, sample in enumerate(sample_images):
if not os.path.exists(sample):
img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
img.save(sample)
# Gradio prediction function
def predict_image(image, instruction):
    """
    Processes the uploaded image and instruction through the FloorPlanAnalyzer.

    Args:
        image (PIL.Image.Image | None): The uploaded floor plan image. Gradio
            passes None when the user submits without uploading an image.
        instruction (str): The user-provided instruction.

    Returns:
        str: The generated output description, or a prompt asking for an
        image when none was provided.
    """
    # Avoid sending None into the model pipeline; show a readable hint instead.
    if image is None:
        return "Please upload a floor plan image."
    return analyzer.analyze(image, instruction)
# Prompt pre-filled in the instruction box; it is also reused for each example row.
DEFAULT_INSTRUCTION = """You are an expert in architecture and interior design. Analyze the floor plan image and describe accurately the key features, room count, layout, and any other important details you observe."""

gr_interface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type="pil", label="Upload Floor Plan Image"),
        gr.Textbox(
            label="Instruction Text",
            value=DEFAULT_INSTRUCTION,
        ),
    ],
    outputs=gr.Textbox(label="Analysis Result"),
    title="Floor Plan Analyzer",
    description="Upload a floor plan image and provide instructions to analyze it. Get detailed insights into the layout and design.",
    # The interface has two inputs, so each example must be an [image, instruction]
    # row; Gradio only accepts a flat list of values for single-input interfaces.
    examples=[[path, DEFAULT_INSTRUCTION] for path in sample_images],
)
# # Gradio UI setup with examples
# gr_interface = gr.Interface(
# fn=predict_image,
# inputs=gr.Image(type="pil"), # Updated to gr.Image for image input
# outputs=[gr.Image(type="pil"), gr.Textbox()], # Updated to gr.Image and gr.Textbox
# title="House CAD Design Object Detection",
# description="Upload a CAD design image of a house to detect objects with bounding boxes and probabilities.",
# examples=sample_images # Add the examples here
# )
def main():
    """Serve the Floor Plan Analyzer Gradio app."""
    gr_interface.launch()


# Start the server only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()