hatmanstack committed on
Commit
059f429
1 Parent(s): cb5a183

refactor with rate limits and troll check

Browse files
Files changed (4) hide show
  1. app.py +4 -4
  2. functions.py +49 -23
  3. generate.py +7 -132
  4. processImage.py +134 -0
app.py CHANGED
@@ -147,15 +147,15 @@ with gr.Blocks() as demo:
147
  with gr.Column():
148
  gr.Markdown("""
149
  <div style="text-align: center;">
150
- Generate an image using a color palette. If you choose to include an image (optional) the subject and style will be used as a reference.
151
  The colors of the image will also be incorporated, along with the colors from the colors list. A color list is always required but one has been provided.
152
  </div>
153
  """)
154
  reference_image = gr.Image(type='pil', label="Reference Image")
155
  colors = gr.Textbox(label="Colors", placeholder="Enter up to 10 colors as hex values, e.g., #00FF00,#FCF2AB", max_lines=1)
156
- with gr.Accordion("Optional Prompt", open=False):
157
- prompt = gr.Textbox(label="Text", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
158
- gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
159
  error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
160
  output = gr.Image()
161
  with gr.Accordion("Advanced Options", open=False):
 
147
  with gr.Column():
148
  gr.Markdown("""
149
  <div style="text-align: center;">
150
+ Generate an image using a color palette. If you include an image with your text prompt, the subject and style will be used as a reference.
151
  The colors of the image will also be incorporated, along with the colors from the colors list. A color list is always required but one has been provided.
152
  </div>
153
  """)
154
  reference_image = gr.Image(type='pil', label="Reference Image")
155
  colors = gr.Textbox(label="Colors", placeholder="Enter up to 10 colors as hex values, e.g., #00FF00,#FCF2AB", max_lines=1)
156
+
157
+ prompt = gr.Textbox(label="Text", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
158
+ gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
159
  error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
160
  output = gr.Image()
161
  with gr.Accordion("Advanced Options", open=False):
functions.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
  from PIL import Image
6
  from generate import *
7
  from typing import Dict, Any
 
8
 
9
  def display_image(image_bytes):
10
  if isinstance(image_bytes, str):
@@ -56,6 +57,12 @@ def build_request(task_type, params, height=1024, width=1024, quality="standard"
56
  )
57
  })
58
 
 
 
 
 
 
 
59
 
60
  def text_to_image(prompt, negative_text=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
61
  text_to_image_params = {"text": prompt,
@@ -63,15 +70,16 @@ def text_to_image(prompt, negative_text=None, height=1024, width=1024, quality="
63
  }
64
 
65
  body = build_request("TEXT_IMAGE", text_to_image_params, height, width, quality, cfg_scale, seed)
66
- image_bytes = generate_image(body)
67
- return display_image(image_bytes)
 
68
 
69
  def inpainting(image, mask_prompt=None, mask_image=None, text=None, negative_text=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
70
  images = process_images(primary=image, secondary=None)
71
 
72
  for value in images.values():
73
- if isinstance(value, str) and "Not Appropriate" in value:
74
- return None, gr.update(visible=True, value="Image <b>Not Appropriate</b>")
75
  # Prepare the inPaintingParams dictionary
76
  if mask_prompt and mask_image:
77
  raise ValueError("You must specify either maskPrompt or maskImage, but not both.")
@@ -87,13 +95,15 @@ def inpainting(image, mask_prompt=None, mask_image=None, text=None, negative_tex
87
  }
88
 
89
  body = build_request("INPAINTING", in_painting_params, height, width, quality, cfg_scale, seed)
90
- return display_image(generate_image(body))
 
 
91
 
92
  def outpainting(image, mask_prompt=None, mask_image=None, text=None, negative_text=None, outpainting_mode="DEFAULT", height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
93
  images = process_images(primary=image, secondary=None)
94
  for value in images.values():
95
- if isinstance(value, str) and "Not Appropriate" in value:
96
- return None, gr.update(visible=True, value="Image <b>Not Appropriate</b>")
97
 
98
  if mask_prompt and mask_image:
99
  raise ValueError("You must specify either maskPrompt or maskImage, but not both.")
@@ -111,13 +121,19 @@ def outpainting(image, mask_prompt=None, mask_image=None, text=None, negative_te
111
  }
112
 
113
  body = build_request("OUTPAINTING", out_painting_params, height, width, quality, cfg_scale, seed)
114
- return display_image(generate_image(body))
 
 
115
 
116
  def image_variation(images, text=None, negative_text=None, similarity_strength=0.5, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
117
  encoded_images = []
118
  for image_path in images:
119
  with open(image_path, "rb") as image_file:
120
- encoded_images.append(process_and_encode_image(image_file))
 
 
 
 
121
 
122
  # Prepare the imageVariationParams dictionary
123
  image_variation_params = {
@@ -127,30 +143,34 @@ def image_variation(images, text=None, negative_text=None, similarity_strength=0
127
  }
128
 
129
  body = build_request("IMAGE_VARIATION", image_variation_params, height, width, quality, cfg_scale, seed)
130
- return display_image(generate_image(body))
 
 
131
 
132
  def image_conditioning(condition_image, text, negative_text=None, control_mode="CANNY_EDGE", control_strength=0.7, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
133
  condition_image_encoded = process_images(primary=condition_image)
134
  for value in condition_image_encoded.values():
135
- if isinstance(value, str) and "Not Appropriate" in value:
136
- return None, gr.update(visible=True, value="Image <b>Not Appropriate</b>")
137
  # Prepare the textToImageParams dictionary
138
  text_to_image_params = {
139
  "text": text,
140
  "controlMode": control_mode,
141
  "controlStrength": control_strength,
142
- **condition_image_encoded,
143
  **({"negativeText": negative_text} if negative_text not in [None, ""] else {})
144
  }
145
  body = build_request("TEXT_IMAGE", text_to_image_params, height, width, quality, cfg_scale, seed)
146
- return display_image(generate_image(body))
 
 
147
 
148
  def color_guided_content(text=None, reference_image=None, negative_text=None, colors=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
149
  # Encode the reference image if provided
150
  reference_image_encoded = process_images(primary=reference_image)
151
  for value in reference_image_encoded.values():
152
- if isinstance(value, str) and "Not Appropriate" in value:
153
- return None, gr.update(visible=True, value="Image <b>Not Appropriate</b>")
154
 
155
  if not colors:
156
  colors = "#FF5733,#33FF57,#3357FF,#FF33A1,#33FFF5,#FF8C33,#8C33FF,#33FF8C,#FF3333,#33A1FF"
@@ -158,24 +178,30 @@ def color_guided_content(text=None, reference_image=None, negative_text=None, co
158
  color_guided_generation_params = {
159
  "text": text,
160
  "colors": colors.split(','),
161
- **reference_image_encoded,
162
  **({"negativeText": negative_text} if negative_text not in [None, ""] else {})
163
  }
164
 
165
  body = build_request("COLOR_GUIDED_GENERATION", color_guided_generation_params, height, width, quality, cfg_scale, seed)
166
- return display_image(generate_image(body))
 
 
167
 
168
  def background_removal(image):
169
- input_image = process_and_encode_image(image)
170
  for value in input_image.values():
171
- if isinstance(value, str) and "Not Appropriate" in value:
172
- return None, gr.update(visible=True, value="Image <b>Not Appropriate</b>")
173
 
174
  body = json.dumps({
175
  "taskType": "BACKGROUND_REMOVAL",
176
- "backgroundRemovalParams": {"image": input_image}
 
 
177
  })
178
- return display_image(generate_image(body))
 
 
179
 
180
  def generate_nova_prompt():
181
 
 
5
  from PIL import Image
6
  from generate import *
7
  from typing import Dict, Any
8
+ from processImage import process_and_encode_image
9
 
10
  def display_image(image_bytes):
11
  if isinstance(image_bytes, str):
 
57
  )
58
  })
59
 
60
def check_return(result):
    """Translate a generation result into the (image, error-box) pair Gradio expects.

    Raw image bytes are decoded into a PIL image with the error box hidden;
    any non-bytes result is treated as an error message and shown in a
    visible error box with no image.
    """
    if isinstance(result, bytes):
        return Image.open(io.BytesIO(result)), gr.update(visible=False)
    return None, gr.update(visible=True, value=result)
65
+
66
 
67
  def text_to_image(prompt, negative_text=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
68
  text_to_image_params = {"text": prompt,
 
70
  }
71
 
72
  body = build_request("TEXT_IMAGE", text_to_image_params, height, width, quality, cfg_scale, seed)
73
+ result = generate_image(body)
74
+ return check_return(result)
75
+
76
 
77
  def inpainting(image, mask_prompt=None, mask_image=None, text=None, negative_text=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
78
  images = process_images(primary=image, secondary=None)
79
 
80
  for value in images.values():
81
+ if len(value) < 200:
82
+ return None, gr.update(visible=True, value=value)
83
  # Prepare the inPaintingParams dictionary
84
  if mask_prompt and mask_image:
85
  raise ValueError("You must specify either maskPrompt or maskImage, but not both.")
 
95
  }
96
 
97
  body = build_request("INPAINTING", in_painting_params, height, width, quality, cfg_scale, seed)
98
+ result = generate_image(body)
99
+
100
+ return check_return(result)
101
 
102
  def outpainting(image, mask_prompt=None, mask_image=None, text=None, negative_text=None, outpainting_mode="DEFAULT", height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
103
  images = process_images(primary=image, secondary=None)
104
  for value in images.values():
105
+ if len(value) < 200:
106
+ return None, gr.update(visible=True, value=value)
107
 
108
  if mask_prompt and mask_image:
109
  raise ValueError("You must specify either maskPrompt or maskImage, but not both.")
 
121
  }
122
 
123
  body = build_request("OUTPAINTING", out_painting_params, height, width, quality, cfg_scale, seed)
124
+ result = generate_image(body)
125
+
126
+ return check_return(result)
127
 
128
  def image_variation(images, text=None, negative_text=None, similarity_strength=0.5, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
129
  encoded_images = []
130
  for image_path in images:
131
  with open(image_path, "rb") as image_file:
132
+ value = process_and_encode_image(image_file)
133
+
134
+ if len(value) < 200:
135
+ return None, gr.update(visible=True, value=value)
136
+ encoded_images.append(value)
137
 
138
  # Prepare the imageVariationParams dictionary
139
  image_variation_params = {
 
143
  }
144
 
145
  body = build_request("IMAGE_VARIATION", image_variation_params, height, width, quality, cfg_scale, seed)
146
+ result = generate_image(body)
147
+
148
+ return check_return(result)
149
 
150
  def image_conditioning(condition_image, text, negative_text=None, control_mode="CANNY_EDGE", control_strength=0.7, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
151
  condition_image_encoded = process_images(primary=condition_image)
152
  for value in condition_image_encoded.values():
153
+ if len(value) < 200:
154
+ return None, gr.update(visible=True, value=value)
155
  # Prepare the textToImageParams dictionary
156
  text_to_image_params = {
157
  "text": text,
158
  "controlMode": control_mode,
159
  "controlStrength": control_strength,
160
+ "conditionImage": condition_image_encoded.get('image'),
161
  **({"negativeText": negative_text} if negative_text not in [None, ""] else {})
162
  }
163
  body = build_request("TEXT_IMAGE", text_to_image_params, height, width, quality, cfg_scale, seed)
164
+ result = generate_image(body)
165
+
166
+ return check_return(result)
167
 
168
  def color_guided_content(text=None, reference_image=None, negative_text=None, colors=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
169
  # Encode the reference image if provided
170
  reference_image_encoded = process_images(primary=reference_image)
171
  for value in reference_image_encoded.values():
172
+ if len(value) < 200:
173
+ return None, gr.update(visible=True, value=value)
174
 
175
  if not colors:
176
  colors = "#FF5733,#33FF57,#3357FF,#FF33A1,#33FFF5,#FF8C33,#8C33FF,#33FF8C,#FF3333,#33A1FF"
 
178
  color_guided_generation_params = {
179
  "text": text,
180
  "colors": colors.split(','),
181
+ "referenceImage": reference_image_encoded.get('image'),
182
  **({"negativeText": negative_text} if negative_text not in [None, ""] else {})
183
  }
184
 
185
  body = build_request("COLOR_GUIDED_GENERATION", color_guided_generation_params, height, width, quality, cfg_scale, seed)
186
+ result = generate_image(body)
187
+
188
+ return check_return(result)
189
 
190
  def background_removal(image):
191
+ input_image = process_images(primary=image)
192
  for value in input_image.values():
193
+ if len(value) < 200:
194
+ return None, gr.update(visible=True, value=value)
195
 
196
  body = json.dumps({
197
  "taskType": "BACKGROUND_REMOVAL",
198
+ "backgroundRemovalParams": {
199
+ "image": input_image.get('image')
200
+ }
201
  })
202
+ result = generate_image(body)
203
+
204
+ return check_return(result)
205
 
206
  def generate_nova_prompt():
207
 
generate.py CHANGED
@@ -3,14 +3,9 @@ import base64
3
  import boto3
4
  import json
5
  import logging
6
- import io
7
- import time
8
- import requests
9
  from datetime import datetime
10
  from dotenv import load_dotenv
11
- from PIL import Image
12
  from functools import wraps
13
- from dataclasses import dataclass
14
  from botocore.config import Config
15
  from botocore.exceptions import ClientError
16
 
@@ -35,121 +30,13 @@ def handle_bedrock_errors(func):
35
  raise ImageError(f"Unexpected error: {str(err)}")
36
  return wrapper
37
 
38
- @dataclass
39
- class ImageConfig:
40
- min_size: int = 320
41
- max_size: int = 4096
42
- max_pixels: int = 4194304
43
- quality: str = "standard"
44
- format: str = "PNG"
45
-
46
- config = ImageConfig()
47
-
48
- model_id = 'amazon.nova-canvas-v1:0'
49
  aws_id = os.getenv('AWS_ID')
50
  aws_secret = os.getenv('AWS_SECRET')
51
- token = os.environ.get("HF_TOKEN")
52
- headers = {"Authorization": f"Bearer {token}", "x-use-cache": "0", 'Content-Type': 'application/json'}
53
  nova_image_bucket='nova-image-data'
54
  bucket_region='us-west-2'
55
-
56
- class ImageProcessor:
57
- def __init__(self, image):
58
- self.image = self._open_image(image)
59
-
60
- def _open_image(self, image):
61
- """Convert input to PIL Image if necessary."""
62
- if image is None:
63
- raise ValueError("Input image is required.")
64
- return Image.open(image) if not isinstance(image, Image.Image) else image
65
-
66
- def _check_nsfw(self, attempts=1):
67
- """Check if image is NSFW using Hugging Face API."""
68
- API_URL = "https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection"
69
-
70
- # Prepare image data
71
- temp_buffer = io.BytesIO()
72
- self.image.save(temp_buffer, format='PNG')
73
- temp_buffer.seek(0)
74
-
75
- try:
76
- response = requests.request("POST", API_URL, headers=headers, data=temp_buffer.getvalue())
77
- json_response = json.loads(response.content.decode("utf-8"))
78
- print(json_response)
79
- if "error" in json_response:
80
- if attempts > 30:
81
- raise ImageError("NSFW check failed after multiple attempts")
82
- time.sleep(json_response["estimated_time"])
83
- return self._check_nsfw(attempts + 1)
84
-
85
- nsfw_score = next((item['score'] for item in json_response if item['label'] == 'nsfw'), 0)
86
- print(f"NSFW Score: {nsfw_score}")
87
-
88
- if nsfw_score > 0.1:
89
- return None
90
-
91
- return self
92
-
93
- except json.JSONDecodeError as e:
94
- raise ImageError(f"NSFW check failed: Invalid response format - {str(e)}")
95
- except Exception as e:
96
- if attempts > 30:
97
- raise ImageError("NSFW check failed after multiple attempts")
98
- return self._check_nsfw(attempts + 1)
99
-
100
- def _convert_color_mode(self):
101
- """Handle color mode conversion."""
102
- if self.image.mode not in ('RGB', 'RGBA'):
103
- self.image = self.image.convert('RGB')
104
- elif self.image.mode == 'RGBA':
105
- background = Image.new('RGB', self.image.size, (255, 255, 255))
106
- background.paste(self.image, mask=self.image.split()[3])
107
- self.image = background
108
- return self
109
-
110
- def _resize_for_pixels(self, max_pixels):
111
- """Resize image to meet pixel limit."""
112
- current_pixels = self.image.width * self.image.height
113
- if current_pixels > max_pixels:
114
- aspect_ratio = self.image.width / self.image.height
115
- if aspect_ratio > 1:
116
- new_width = int((max_pixels * aspect_ratio) ** 0.5)
117
- new_height = int(new_width / aspect_ratio)
118
- else:
119
- new_height = int((max_pixels / aspect_ratio) ** 0.5)
120
- new_width = int(new_height * aspect_ratio)
121
- self.image = self.image.resize((new_width, new_height), Image.LANCZOS)
122
- return self
123
-
124
- def _ensure_dimensions(self, min_size=320, max_size=4096):
125
- if (self.image.width < min_size or
126
- self.image.width > max_size or
127
- self.image.height < min_size or
128
- self.image.height > max_size):
129
-
130
- new_width = min(max(self.image.width, min_size), max_size)
131
- new_height = min(max(self.image.height, min_size), max_size)
132
- self.image = self.image.resize((new_width, new_height), Image.LANCZOS)
133
-
134
- return self
135
-
136
- def encode(self):
137
- image_bytes = io.BytesIO()
138
- self.image.save(image_bytes, format='PNG', optimize=True)
139
- return base64.b64encode(image_bytes.getvalue()).decode('utf8')
140
-
141
- def process(self, min_size=320, max_size=4096, max_pixels=4194304):
142
- """Process image with all necessary transformations."""
143
- result = (self
144
- ._convert_color_mode()
145
- ._resize_for_pixels(max_pixels)
146
- ._ensure_dimensions(min_size, max_size)
147
- ._check_nsfw()) # Add NSFW check before encoding
148
-
149
- if result is None:
150
- raise ImageError("Image <b>Not Appropriate</b>")
151
-
152
- return result.encode()
153
 
154
  # Function to generate an image using Amazon Nova Canvas model
155
  class BedrockClient:
@@ -282,16 +169,12 @@ def check_rate_limit(body):
282
 
283
  # Check limits based on quality
284
  if quality == 'premium':
285
- if len(rate_data['premium']) >= 4:
286
- raise ImageError("""<div style='text-align: center;'>Premium rate limit exceeded. Check back later, use the
287
- <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a> or
288
- try it out without an AWS account on <a href='https://partyrock.aws/'>PartyRock</a>.</div>""")
289
  rate_data['premium'].append(current_time)
290
  else: # standard
291
- if len(rate_data['standard']) >= 8:
292
- raise ImageError("""<div style='text-align: center;'>Standard rate limit exceeded. Check back later, use the
293
- <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a> or
294
- try it out without an AWS account on <a href='https://partyrock.aws/'>PartyRock</a>.</div>""")
295
  rate_data['standard'].append(current_time)
296
 
297
  # Update rate limit file
@@ -303,14 +186,6 @@ def check_rate_limit(body):
303
  )
304
 
305
 
306
- def process_and_encode_image(image, **kwargs):
307
- """Process and encode image with default parameters."""
308
- try:
309
- image = ImageProcessor(image).process(**kwargs)
310
- return image
311
- except ImageError as e:
312
- return str(e)
313
-
314
  def generate_image(body):
315
  """Generate image using Bedrock service."""
316
  try:
 
3
  import boto3
4
  import json
5
  import logging
 
 
 
6
  from datetime import datetime
7
  from dotenv import load_dotenv
 
8
  from functools import wraps
 
9
  from botocore.config import Config
10
  from botocore.exceptions import ClientError
11
 
 
30
  raise ImageError(f"Unexpected error: {str(err)}")
31
  return wrapper
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  aws_id = os.getenv('AWS_ID')
34
  aws_secret = os.getenv('AWS_SECRET')
 
 
35
  nova_image_bucket='nova-image-data'
36
  bucket_region='us-west-2'
37
+ rate_limit_message = """<div style='text-align: center;'>{} rate limit exceeded. Check back later, use the
38
+ <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a> or
39
+ try it out without an AWS account on <a href='https://partyrock.aws/'>PartyRock</a>.</div>"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # Function to generate an image using Amazon Nova Canvas model
42
  class BedrockClient:
 
169
 
170
  # Check limits based on quality
171
  if quality == 'premium':
172
+ if len(rate_data['premium']) >= 3:
173
+ raise ImageError(rate_limit_message.format('Premium'))
 
 
174
  rate_data['premium'].append(current_time)
175
  else: # standard
176
+ if len(rate_data['standard']) >= 6:
177
+ raise ImageError(rate_limit_message.format('Standard'))
 
 
178
  rate_data['standard'].append(current_time)
179
 
180
  # Update rate limit file
 
186
  )
187
 
188
 
 
 
 
 
 
 
 
 
189
  def generate_image(body):
190
  """Generate image using Bedrock service."""
191
  try:
processImage.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import json
4
+ import io
5
+ import time
6
+ import requests
7
+ from dotenv import load_dotenv
8
+ from PIL import Image
9
+ from dataclasses import dataclass
10
+
11
+ load_dotenv()
12
+ # Move custom exceptions to the top
13
class ImageError(Exception):
    """Raised when image processing, moderation, or generation fails.

    The message doubles as user-facing text elsewhere in the app (callers
    display ``str(err)``), so keep it presentable.
    """

    def __init__(self, message):
        # Chain to the base class so str(err), err.args, and pickling stay
        # consistent with the stored message (the original skipped this).
        super().__init__(message)
        self.message = message
16
+
17
@dataclass
class ImageConfig:
    """Default size/quality limits for image processing.

    Values mirror the defaults of ImageProcessor.process().
    NOTE(review): the module-level ``config`` instance created below does
    not appear to be consumed anywhere in this module — confirm before
    relying on these fields.
    """
    min_size: int = 320        # minimum width/height in pixels
    max_size: int = 4096       # maximum width/height in pixels
    max_pixels: int = 4194304  # total-pixel budget (2048 * 2048)
    quality: str = "standard"
    format: str = "PNG"
24
+
25
+ config = ImageConfig()
26
+
27
+ token = os.environ.get("HF_TOKEN")
28
+ headers = {"Authorization": f"Bearer {token}", "x-use-cache": "0", 'Content-Type': 'application/json'}
29
+
30
class ImageProcessor:
    """Validate, normalize, moderate, and base64-encode an input image.

    Wraps a PIL image (or anything ``Image.open`` accepts) and exposes a
    fluent pipeline used by ``process()``: color-mode normalization ->
    pixel-budget resize -> dimension clamping -> NSFW moderation ->
    PNG/base64 encoding.
    """

    def __init__(self, image):
        # Open/validate eagerly so construction fails fast on bad input.
        self.image = self._open_image(image)

    def _open_image(self, image):
        """Convert input to PIL Image if necessary."""
        if image is None:
            raise ValueError("Input image is required.")
        # Accepts a path/file-like object or an already-open PIL image.
        return Image.open(image) if not isinstance(image, Image.Image) else image

    def _check_nsfw(self, attempts=1):
        """Check if image is NSFW using Hugging Face API.

        Returns ``self`` when the image passes, ``None`` when the nsfw
        score exceeds 0.1, and retries recursively (up to ~30 attempts)
        while the hosted model is warming up or the request errors out.
        Raises ImageError when retries are exhausted or the response is
        not valid JSON.
        """
        API_URL = "https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection"

        # Prepare image data
        temp_buffer = io.BytesIO()
        self.image.save(temp_buffer, format='PNG')
        temp_buffer.seek(0)

        try:
            response = requests.request("POST", API_URL, headers=headers, data=temp_buffer.getvalue())
            json_response = json.loads(response.content.decode("utf-8"))
            print(json_response)
            if "error" in json_response:
                # Model cold start: the API reports an error plus an
                # estimated warm-up time; sleep that long and retry.
                if attempts > 30:
                    raise ImageError("NSFW check failed after multiple attempts")
                time.sleep(json_response["estimated_time"])
                return self._check_nsfw(attempts + 1)

            # Response is a list of {label, score}; missing 'nsfw' -> 0.
            nsfw_score = next((item['score'] for item in json_response if item['label'] == 'nsfw'), 0)
            print(f"NSFW Score: {nsfw_score}")

            if nsfw_score > 0.1:
                return None

            return self

        except json.JSONDecodeError as e:
            raise ImageError(f"NSFW check failed: Invalid response format - {str(e)}")
        except Exception as e:
            # NOTE(review): this broad handler also catches the ImageError
            # raised above; since attempts > 30 re-raises, the net effect
            # is the same, but a narrower except would be clearer.
            if attempts > 30:
                raise ImageError("NSFW check failed after multiple attempts")
            return self._check_nsfw(attempts + 1)

    def _convert_color_mode(self):
        """Handle color mode conversion."""
        if self.image.mode not in ('RGB', 'RGBA'):
            self.image = self.image.convert('RGB')
        elif self.image.mode == 'RGBA':
            # Flatten transparency onto a white background; band 3 is alpha.
            background = Image.new('RGB', self.image.size, (255, 255, 255))
            background.paste(self.image, mask=self.image.split()[3])
            self.image = background
        return self

    def _resize_for_pixels(self, max_pixels):
        """Resize image to meet pixel limit."""
        current_pixels = self.image.width * self.image.height
        if current_pixels > max_pixels:
            # Solve new_w * new_h <= max_pixels while keeping aspect ratio.
            aspect_ratio = self.image.width / self.image.height
            if aspect_ratio > 1:
                new_width = int((max_pixels * aspect_ratio) ** 0.5)
                new_height = int(new_width / aspect_ratio)
            else:
                new_height = int((max_pixels / aspect_ratio) ** 0.5)
                new_width = int(new_height * aspect_ratio)
            self.image = self.image.resize((new_width, new_height), Image.LANCZOS)
        return self

    def _ensure_dimensions(self, min_size=320, max_size=4096):
        # Clamp each side into [min_size, max_size]. NOTE(review): the two
        # axes are clamped independently, so this step can distort the
        # aspect ratio — confirm that is acceptable downstream.
        if (self.image.width < min_size or
            self.image.width > max_size or
            self.image.height < min_size or
            self.image.height > max_size):

            new_width = min(max(self.image.width, min_size), max_size)
            new_height = min(max(self.image.height, min_size), max_size)
            self.image = self.image.resize((new_width, new_height), Image.LANCZOS)

        return self

    def encode(self):
        """Serialize the current image to a base64-encoded PNG string."""
        image_bytes = io.BytesIO()
        self.image.save(image_bytes, format='PNG', optimize=True)
        return base64.b64encode(image_bytes.getvalue()).decode('utf8')

    def process(self, min_size=320, max_size=4096, max_pixels=4194304):
        """Process image with all necessary transformations.

        Returns the base64-encoded PNG, or raises ImageError when the
        moderation check flags the image.
        """
        result = (self
                  ._convert_color_mode()
                  ._resize_for_pixels(max_pixels)
                  ._ensure_dimensions(min_size, max_size)
                  ._check_nsfw())  # Add NSFW check before encoding

        # _check_nsfw returns None for flagged images; surface that as an
        # ImageError whose HTML message the UI layer displays verbatim.
        if result is None:
            raise ImageError("Image <b>Not Appropriate</b>")

        return result.encode()
127
+
128
def process_and_encode_image(image, **kwargs):
    """Run the full ImageProcessor pipeline on *image*.

    Returns the base64-encoded PNG string on success. On an ImageError
    (moderation rejection or API failure) returns the error message as a
    plain string instead of raising, so callers can route it to the UI.
    """
    try:
        return ImageProcessor(image).process(**kwargs)
    except ImageError as err:
        return str(err)