# File size: 5,300 Bytes
# Revision: 40891b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163


from PIL import Image
# import torch
import os
import gradio as gr

# Below is the code refactored into a Python class for better modularity and reusability.


import torch
from transformers import TextStreamer

class FloorPlanAnalyzer:
    """
    Wraps a vision-language model loaded through unsloth's ``FastVisionModel``
    and turns a floor plan image plus a text instruction into generated text.
    """

    def __init__(self, model_path, load_in_4bit=True, gradient_checkpointing="unsloth", device="cpu"):
        """
        Initializes the FloorPlanAnalyzer with the specified model and configuration.

        Args:
            model_path (str): Location handed to ``FastVisionModel.from_pretrained``.
            load_in_4bit (bool): Forwarded as ``load_in_4bit``.
            gradient_checkpointing: Forwarded as ``use_gradient_checkpointing``.
            device (str): Device the prepared model inputs are moved to.
        """
        # Imported here so unsloth is only required once an analyzer is built.
        from unsloth import FastVisionModel  # Assuming unsloth package is installed

        # NOTE(review): only the *inputs* are moved to this device; the model
        # stays wherever from_pretrained placed it. Confirm both end up on the
        # same device for the deployment target.
        self.device = device
        self.model, self.tokenizer = FastVisionModel.from_pretrained(
            model_path,
            load_in_4bit=load_in_4bit,
            use_gradient_checkpointing=gradient_checkpointing,
        )
        FastVisionModel.for_inference(self.model)

    def prepare_input(self, image_path, instruction):
        """
        Prepares the input for the model by loading the image and applying the chat template.

        Args:
            image_path: Path (or file object) of the floor plan image, or an
                already-loaded image object exposing ``convert`` (for example
                the PIL image Gradio supplies with ``type="pil"``).
            instruction (str): Instruction text to guide the analysis.

        Returns:
            The tokenizer/processor output moved to ``self.device``.
        """
        if hasattr(image_path, "convert"):
            # Already an image object: opening it again as a file would fail.
            image = image_path.convert("RGB")
        else:
            # Path or file object: close the handle once the pixels are copied.
            with Image.open(image_path) as opened_image:
                image = opened_image.convert("RGB")

        # Single user turn: an image placeholder followed by the instruction.
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]

        # Render the conversation into the prompt text the model expects.
        input_text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)

        # The chat template is already applied to the prompt, so the tokenizer
        # is told not to add special tokens itself.
        inputs = self.tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)

        return inputs

    def analyze(self, image_path, instruction, max_new_tokens=512, temperature=1.5, min_p=0.1):
        """
        Analyzes the floor plan based on the provided instruction.

        Args:
            image_path: Path to the floor plan image, or an already-loaded image.
            instruction (str): Instruction guiding the analysis.
            max_new_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature forwarded to ``generate``.
            min_p (float): Min-p sampling threshold forwarded to ``generate``.

        Returns:
            str: The generated text, without the prompt and special tokens.
        """
        inputs = self.prepare_input(image_path, instruction)

        # Streams tokens to stdout while generating; the prompt is not echoed.
        text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)

        # NOTE(review): temperature/min_p are sampling parameters; transformers
        # normally applies them only when sampling is enabled - confirm the
        # model's generation config turns sampling on if that is intended.
        generated = self.model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            temperature=temperature,
            min_p=min_p,
        )

        # generate() returns token ids, and for decoder-only models they start
        # with the prompt. Drop the prompt and decode the rest so callers get
        # the text this method documents rather than a tensor.
        prompt_length = inputs["input_ids"].shape[-1]
        completion_ids = generated[0][prompt_length:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True)




# Instantiate the FloorPlanAnalyzer once at import time; the Gradio callback
# below reuses this single instance for every request.
model_path = "./model/"
analyzer = FloorPlanAnalyzer(model_path=model_path)


# Sample images offered as ready-made examples in the Gradio UI.
sample_images = [
    "./sample/10_2.jpg",
    "./sample/10_10.jpg",
    "./sample/0_10.jpg",
    "./sample/2_12.jpg"
]

# Write a flat grey placeholder for every sample that is missing so the example
# gallery always has a file to load (replace these with real floor plans).
for i, sample in enumerate(sample_images):
    if not os.path.exists(sample):
        # Create the directory the file actually lives in; a fixed directory
        # name that differs from the paths above would make save() fail.
        os.makedirs(os.path.dirname(sample) or ".", exist_ok=True)
        img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
        img.save(sample)

# Gradio prediction function
def predict_image(image, instruction):
    """
    Processes the uploaded image and instruction through the FloorPlanAnalyzer.

    Args:
        image (PIL.Image.Image): The uploaded floor plan image, or None when
            nothing was uploaded.
        instruction (str): The user-provided instruction.

    Returns:
        str: The analyzer's result, or a short prompt asking for an image when
        none was provided.
    """
    # A submit without an upload delivers no image; answer with a readable
    # message instead of letting the analyzer fail on a missing input.
    if image is None:
        return "Please upload a floor plan image before requesting an analysis."
    return analyzer.analyze(image, instruction)

# Default prompt shown in the instruction box; reused to fill the example rows.
default_instruction = """You are an expert in architecture and interior design. Analyze the floor plan image and describe accurately the key features, room count, layout, and any other important details you observe."""

gr_interface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type="pil", label="Upload Floor Plan Image"),
        gr.Textbox(
            label="Instruction Text",
            value=default_instruction
        )
    ],
    outputs=gr.Textbox(label="Analysis Result"),
    title="Floor Plan Analyzer",
    description="Upload a floor plan image and provide instructions to analyze it. Get detailed insights into the layout and design.",
    # With more than one input component, every example must be a row holding
    # one value per input, so pair each sample image with the default prompt.
    examples=[[sample_path, default_instruction] for sample_path in sample_images]
)

# # Gradio UI setup with examples
# gr_interface = gr.Interface(
#     fn=predict_image, 
#     inputs=gr.Image(type="pil"),  # Updated to gr.Image for image input
#     outputs=[gr.Image(type="pil"), gr.Textbox()],  # Updated to gr.Image and gr.Textbox
#     title="House CAD Design Object Detection", 
#     description="Upload a CAD design image of a house to detect objects with bounding boxes and probabilities.",
#     examples=sample_images  # Add the examples here
# )

# Start the Gradio web server only when this file is executed as a script;
# importing the module builds the interface but does not launch it.
if __name__ == "__main__":
    gr_interface.launch()