|
""" |
|
Usage: |
|
python3 -m unittest tests.test_image_utils |
|
""" |
|
|
|
import base64 |
|
from io import BytesIO |
|
import os |
|
import unittest |
|
|
|
import numpy as np |
|
from PIL import Image |
|
|
|
from fastchat.utils import ( |
|
resize_image_and_return_image_in_bytes, |
|
image_moderation_filter, |
|
) |
|
from fastchat.conversation import get_conv_template |
|
|
|
|
|
def check_byte_size_in_mb(image_base64_str):
    """Return ``len(image_base64_str)`` expressed in mebibytes (1024 * 1024).

    Accepts anything that supports ``len()``. For ``bytes`` this is the byte
    count; for an ASCII-only ``str`` (such as base64 text) the character count
    equals the encoded byte count.
    """
    bytes_per_mb = 1024 * 1024
    return len(image_base64_str) / bytes_per_mb
|
|
|
|
|
def generate_random_image(target_size_mb, image_format="PNG"):
    """Create a square random-noise RGB image whose encoding is >= *target_size_mb*.

    The encoded size is measured entirely in memory, so no temporary file is
    written to (or left behind in) the current working directory.

    Args:
        target_size_mb: Minimum encoded size, in mebibytes (1024 * 1024 bytes).
        image_format: Pillow format name used to measure the encoded size
            (e.g. "PNG").

    Returns:
        A ``PIL.Image.Image`` whose encoding in *image_format* is at least
        *target_size_mb* MB.
    """
    target_size_bytes = target_size_mb * 1024 * 1024

    def _random_image(side):
        # Uniform noise; its encoded size scales roughly with the pixel count.
        pixels = np.random.randint(0, 256, (side, side, 3), dtype=np.uint8)
        return Image.fromarray(pixels)

    def _encoded_size(image):
        # Encode into memory instead of a shared temp file on disk.
        buffer = BytesIO()
        image.save(buffer, format=image_format)
        return len(buffer.getvalue())

    # Initial guess: 3 bytes per RGB pixel; never let the side collapse to 0.
    dimension = max(1, int((target_size_bytes / 3) ** 0.5))
    img = _random_image(dimension)
    size = _encoded_size(img)

    while size < target_size_bytes:
        # Size grows ~quadratically with the side, so scale the side by
        # sqrt(target / actual). Always grow by at least 1 px so the loop
        # terminates even when the estimate rounds down.
        scale = (target_size_bytes / max(size, 1)) ** 0.5
        dimension = max(dimension + 1, int(dimension * scale))
        img = _random_image(dimension)
        size = _encoded_size(img)

    return img
|
|
|
|
|
class DontResizeIfLessThanMaxTest(unittest.TestCase):
    """An image already below the size cap must come back with the same size."""

    def test_dont_resize_if_less_than_max(self):
        # Roughly 0.1 MB of noise against a 5 MB cap.
        source_image = generate_random_image(0.1)

        reference_buffer = BytesIO()
        source_image.save(reference_buffer, format="PNG")
        size_before = check_byte_size_in_mb(reference_buffer.getvalue())

        output_buffer = resize_image_and_return_image_in_bytes(
            source_image, max_image_size_mb=5
        )
        size_after = check_byte_size_in_mb(output_buffer.getvalue())

        self.assertEqual(size_before, size_after)
|
|
|
|
|
class ResizeLargeImageForModerationEndpoint(unittest.TestCase):
    """A large random image passed to the moderation filter must not be flagged."""

    def test_resize_large_image_and_send_to_moderation_filter(self):
        # ~6 MB of noise; presumably above what the moderation endpoint
        # accepts, so the filter must resize it first — confirm against
        # image_moderation_filter.
        initial_size_mb = 6
        img = generate_random_image(initial_size_mb)

        nsfw_flag, csam_flag = image_moderation_filter(img)

        # Random noise is expected to trip neither classifier; check both flags.
        self.assertFalse(nsfw_flag)
        self.assertFalse(csam_flag)
|
|
|
|
|
class DontResizeIfMaxImageSizeIsNone(unittest.TestCase):
    """With no size cap (``max_image_size_mb=None``) the image size is unchanged."""

    def test_dont_resize_if_max_image_size_is_none(self):
        source_image = generate_random_image(0.2)

        reference_buffer = BytesIO()
        source_image.save(reference_buffer, format="PNG")
        size_before = check_byte_size_in_mb(reference_buffer.getvalue())

        output_buffer = resize_image_and_return_image_in_bytes(
            source_image, max_image_size_mb=None
        )
        size_after = check_byte_size_in_mb(output_buffer.getvalue())

        self.assertEqual(size_before, size_after)
|
|
|
|
|
class OpenAIConversationDontResizeImage(unittest.TestCase):
    """The "chatgpt" template must not shrink a small image when base64-encoding it."""

    def test(self):
        conv = get_conv_template("chatgpt")
        source_image = generate_random_image(0.2)

        reference_buffer = BytesIO()
        source_image.save(reference_buffer, format="PNG")
        expected_size = check_byte_size_in_mb(reference_buffer.getvalue())

        # Decode the base64 text back to raw bytes before comparing sizes.
        encoded = conv.convert_image_to_base64(source_image)
        decoded_size = check_byte_size_in_mb(base64.b64decode(encoded))

        self.assertEqual(expected_size, decoded_size)
|
|
|
|
|
class ClaudeConversationResizesCorrectly(unittest.TestCase):
    """The Claude template must shrink an oversized image below its size caps."""

    def test(self):
        conv = get_conv_template("claude-3-haiku-20240307")
        oversized_image = generate_random_image(5)

        reference_buffer = BytesIO()
        oversized_image.save(reference_buffer, format="PNG")
        size_before_mb = check_byte_size_in_mb(reference_buffer.getvalue())

        # Measure both the base64 string itself and its decoded payload.
        encoded = conv.convert_image_to_base64(oversized_image)
        encoded_size_mb = check_byte_size_in_mb(encoded)
        decoded_size_mb = check_byte_size_in_mb(base64.b64decode(encoded))

        # NOTE(review): presumably mirrors Anthropic's per-image limit on the
        # base64 payload — confirm.
        base64_limit_mb = 5

        self.assertLess(decoded_size_mb, size_before_mb)
        self.assertLessEqual(decoded_size_mb, conv.max_image_size_mb)
        self.assertLessEqual(encoded_size_mb, base64_limit_mb)
|
|