sabaridsnfuji committed on
Commit 40891b3
1 Parent(s): a6bf3ef

Update app.py

Files changed (1)
  1. app.py +162 -162
app.py CHANGED
@@ -1,162 +1,162 @@
-
-
- from PIL import Image
- # import torch
- import os
- import gradio as gr
-
- # Below is the code refactored into a Python class for better modularity and reusability.
-
-
-
- from transformers import TextStreamer
-
- class FloorPlanAnalyzer:
-     def __init__(self, model_path, load_in_4bit=True, gradient_checkpointing="unsloth", device="cpu"):
-         """
-         Initializes the FloorPlanAnalyzer with the specified model and configuration.
-         """
-         from unsloth import FastVisionModel # Assuming unsloth package is installed
-
-         self.device = device
-         self.model, self.tokenizer = FastVisionModel.from_pretrained(
-             model_path,
-             load_in_4bit=load_in_4bit,
-             use_gradient_checkpointing=gradient_checkpointing,
-         )
-         FastVisionModel.for_inference(self.model)
-
-     def prepare_input(self, image_path, instruction):
-         """
-         Prepares the input for the model by loading the image and applying the chat template.
-
-         Args:
-             image_path (str): Path to the floor plan image.
-             instruction (str): Instruction text to guide the analysis.
-
-         Returns:
-             torch.Tensor: Processed inputs for the model.
-         """
-         # Load image
-         image = Image.open(image_path).convert("RGB")
-
-         # Create message template
-         messages = [
-             {"role": "user", "content": [
-                 {"type": "image"},
-                 {"type": "text", "text": instruction}
-             ]}
-         ]
-
-         # Generate input text
-         input_text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-
-         # Tokenize and prepare inputs
-         inputs = self.tokenizer(
-             image,
-             input_text,
-             add_special_tokens=False,
-             return_tensors="pt",
-         ).to(self.device)
-
-         return inputs
-
-     def analyze(self, image_path, instruction, max_new_tokens=512, temperature=1.5, min_p=0.1):
-         """
-         Analyzes the floor plan based on the provided instruction.
-
-         Args:
-             image_path (str): Path to the floor plan image.
-             instruction (str): Instruction guiding the analysis.
-             max_new_tokens (int): Maximum number of tokens to generate.
-             temperature (float): Sampling temperature for generation.
-             min_p (float): Minimum probability for nucleus sampling.
-
-         Returns:
-             str: The generated output from the model.
-         """
-         # Prepare inputs
-         inputs = self.prepare_input(image_path, instruction)
-
-         # Set up text streamer
-         text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
-
-         # Generate output
-         output = self.model.generate(
-             **inputs,
-             streamer=text_streamer,
-             max_new_tokens=max_new_tokens,
-             use_cache=True,
-             temperature=temperature,
-             min_p=min_p,
-         )
-
-         return output
-
-
-
-
- # Instantiate the FloorPlanAnalyzer
- model_path = "./model/"
- analyzer = FloorPlanAnalyzer(model_path=model_path)
-
-
- # Sample images for Gradio examples
- # Define sample images for user convenience
- sample_images = [
-     "./sample/10_2.jpg",
-     "./sample/10_10.jpg",
-     "./sample/0_10.jpg",
-     "./sample/2_12.jpg"
- ]
-
- # Ensure sample images directory exists
- os.makedirs("samples", exist_ok=True)
- # Save some dummy sample images if they don't exist (you should replace these with actual images)
- for i, sample in enumerate(sample_images):
-     if not os.path.exists(sample):
-         img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
-         img.save(sample)
-
- # Gradio prediction function
- def predict_image(image, instruction):
-     """
-     Processes the uploaded image and instruction through the FloorPlanAnalyzer.
-
-     Args:
-         image (PIL.Image.Image): The uploaded floor plan image.
-         instruction (str): The user-provided instruction.
-
-     Returns:
-         str: The generated output description.
-     """
-     return analyzer.analyze(image, instruction)
-
- gr_interface = gr.Interface(
-     fn=predict_image,
-     inputs=[
-         gr.Image(type="pil", label="Upload Floor Plan Image"),
-         gr.Textbox(
-             label="Instruction Text",
-             value="""You are an expert in architecture and interior design. Analyze the floor plan image and describe accurately the key features, room count, layout, and any other important details you observe."""
-         )
-     ],
-     outputs=gr.Textbox(label="Analysis Result"),
-     title="Floor Plan Analyzer",
-     description="Upload a floor plan image and provide instructions to analyze it. Get detailed insights into the layout and design.",
-     examples=sample_images # Add the examples here
- )
-
- # # Gradio UI setup with examples
- # gr_interface = gr.Interface(
- #     fn=predict_image,
- #     inputs=gr.Image(type="pil"), # Updated to gr.Image for image input
- #     outputs=[gr.Image(type="pil"), gr.Textbox()], # Updated to gr.Image and gr.Textbox
- #     title="House CAD Design Object Detection",
- #     description="Upload a CAD design image of a house to detect objects with bounding boxes and probabilities.",
- #     examples=sample_images # Add the examples here
- # )
-
- # Launch the Gradio interface if run as main
- if __name__ == "__main__":
-     gr_interface.launch()
 
+
+
+ from PIL import Image
+ # import torch
+ import os
+ import gradio as gr
+
+ # Below is the code refactored into a Python class for better modularity and reusability.
+
+
+ import torch
+ from transformers import TextStreamer
+
+ class FloorPlanAnalyzer:
+     def __init__(self, model_path, load_in_4bit=True, gradient_checkpointing="unsloth", device="cpu"):
+         """
+         Initializes the FloorPlanAnalyzer with the specified model and configuration.
+         """
+         from unsloth import FastVisionModel # Assuming unsloth package is installed
+
+         self.device = device
+         self.model, self.tokenizer = FastVisionModel.from_pretrained(
+             model_path,
+             load_in_4bit=load_in_4bit,
+             use_gradient_checkpointing=gradient_checkpointing,
+         )
+         FastVisionModel.for_inference(self.model)
+
+     def prepare_input(self, image_path, instruction):
+         """
+         Prepares the input for the model by loading the image and applying the chat template.
+
+         Args:
+             image_path (str): Path to the floor plan image.
+             instruction (str): Instruction text to guide the analysis.
+
+         Returns:
+             torch.Tensor: Processed inputs for the model.
+         """
+         # Load image
+         image = Image.open(image_path).convert("RGB")
+
+         # Create message template
+         messages = [
+             {"role": "user", "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": instruction}
+             ]}
+         ]
+
+         # Generate input text
+         input_text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+         # Tokenize and prepare inputs
+         inputs = self.tokenizer(
+             image,
+             input_text,
+             add_special_tokens=False,
+             return_tensors="pt",
+         ).to(self.device)
+
+         return inputs
+
+     def analyze(self, image_path, instruction, max_new_tokens=512, temperature=1.5, min_p=0.1):
+         """
+         Analyzes the floor plan based on the provided instruction.
+
+         Args:
+             image_path (str): Path to the floor plan image.
+             instruction (str): Instruction guiding the analysis.
+             max_new_tokens (int): Maximum number of tokens to generate.
+             temperature (float): Sampling temperature for generation.
+             min_p (float): Minimum probability for nucleus sampling.
+
+         Returns:
+             str: The generated output from the model.
+         """
+         # Prepare inputs
+         inputs = self.prepare_input(image_path, instruction)
+
+         # Set up text streamer
+         text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
+
+         # Generate output
+         output = self.model.generate(
+             **inputs,
+             streamer=text_streamer,
+             max_new_tokens=max_new_tokens,
+             use_cache=True,
+             temperature=temperature,
+             min_p=min_p,
+         )
+
+         return output
+
+
+
+
+ # Instantiate the FloorPlanAnalyzer
+ model_path = "./model/"
+ analyzer = FloorPlanAnalyzer(model_path=model_path)
+
+
+ # Sample images for Gradio examples
+ # Define sample images for user convenience
+ sample_images = [
+     "./sample/10_2.jpg",
+     "./sample/10_10.jpg",
+     "./sample/0_10.jpg",
+     "./sample/2_12.jpg"
+ ]
+
+ # Ensure sample images directory exists
+ os.makedirs("samples", exist_ok=True)
+ # Save some dummy sample images if they don't exist (you should replace these with actual images)
+ for i, sample in enumerate(sample_images):
+     if not os.path.exists(sample):
+         img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
+         img.save(sample)
+
+ # Gradio prediction function
+ def predict_image(image, instruction):
+     """
+     Processes the uploaded image and instruction through the FloorPlanAnalyzer.
+
+     Args:
+         image (PIL.Image.Image): The uploaded floor plan image.
+         instruction (str): The user-provided instruction.
+
+     Returns:
+         str: The generated output description.
+     """
+     return analyzer.analyze(image, instruction)
+
+ gr_interface = gr.Interface(
+     fn=predict_image,
+     inputs=[
+         gr.Image(type="pil", label="Upload Floor Plan Image"),
+         gr.Textbox(
+             label="Instruction Text",
+             value="""You are an expert in architecture and interior design. Analyze the floor plan image and describe accurately the key features, room count, layout, and any other important details you observe."""
+         )
+     ],
+     outputs=gr.Textbox(label="Analysis Result"),
+     title="Floor Plan Analyzer",
+     description="Upload a floor plan image and provide instructions to analyze it. Get detailed insights into the layout and design.",
+     examples=sample_images # Add the examples here
+ )
+
+ # # Gradio UI setup with examples
+ # gr_interface = gr.Interface(
+ #     fn=predict_image,
+ #     inputs=gr.Image(type="pil"), # Updated to gr.Image for image input
+ #     outputs=[gr.Image(type="pil"), gr.Textbox()], # Updated to gr.Image and gr.Textbox
+ #     title="House CAD Design Object Detection",
+ #     description="Upload a CAD design image of a house to detect objects with bounding boxes and probabilities.",
+ #     examples=sample_images # Add the examples here
+ # )
+
+ # Launch the Gradio interface if run as main
+ if __name__ == "__main__":
+     gr_interface.launch()
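
A note on the updated file, with a hedged sketch: gr.Image(type="pil") hands predict_image a PIL.Image object, but prepare_input reopens its first argument with Image.open, and model.generate returns token IDs rather than the string the Gradio Textbox displays. The sketch below shows one way a wrapper could bridge both gaps; the temporary-file round trip and the batch_decode call are illustrative assumptions, not part of this commit, and `analyzer` refers to the object defined in app.py above.

import os
import tempfile

from PIL import Image


def predict_image(image: Image.Image, instruction: str) -> str:
    # Persist the uploaded PIL image so analyze() can reopen it by path
    # (assumption: a temp-file round trip is acceptable for this app).
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        tmp_path = tmp.name
    image.convert("RGB").save(tmp_path)
    try:
        output_ids = analyzer.analyze(tmp_path, instruction)
    finally:
        os.remove(tmp_path)
    # generate() returns token IDs; decode them into text for the Textbox
    # (assumption: the tokenizer/processor returned by unsloth exposes batch_decode).
    return analyzer.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

Separately, the dummy-image fallback calls os.makedirs("samples", exist_ok=True) while the example paths live under ./sample/, so those fallback saves would only succeed if the directory names are aligned (for example, os.makedirs("sample", exist_ok=True)).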