"""Integration test: run a seeded prediction and compare it to a reference image."""

import base64
import os
import time
from io import BytesIO

import numpy as np
import pytest
import requests
from PIL import Image, ImageChops


def local_run(
    model_endpoint: str,
    model_input: dict,
    *,
    max_wait_time: int = 1000,
    retry_interval: int = 100,
) -> Image.Image:
    """POST ``model_input`` to ``model_endpoint`` and return the first output image.

    The endpoint is expected to answer with JSON. When the JSON contains an
    ``"output"`` key, its first element is treated as a data URI whose
    base64 payload is decoded and opened as an image. When the server reports
    ``"Already running a prediction"``, the request is retried after sleeping.

    Args:
        model_endpoint: URL the prediction request is posted to.
        model_input: Payload sent as ``{"input": model_input}``.
        max_wait_time: Maximum cumulative sleep time, in seconds, spent
            waiting for a busy server before giving up.
        retry_interval: Seconds to sleep between retries; must be positive.

    Returns:
        The decoded output image.

    Raises:
        ValueError: If ``retry_interval`` is not positive.
        RuntimeError: If the response is neither an output nor the
            "already running" notice.
        TimeoutError: If the server stayed busy for ``max_wait_time`` seconds.
    """
    if retry_interval <= 0:
        # A non-positive interval would never advance the wait counter.
        raise ValueError("retry_interval must be positive")

    total_wait_time = 0
    while total_wait_time < max_wait_time:
        response = requests.post(model_endpoint, json={"input": model_input})
        data = response.json()

        if "output" in data:
            try:
                datauri = data["output"][0]
                # Data URI layout: "<header>,<base64 payload>".
                base64_encoded_data = datauri.split(",")[1]
                decoded_data = base64.b64decode(base64_encoded_data)
                return Image.open(BytesIO(decoded_data))
            except Exception:
                print("Error while processing output:")
                print("input:", model_input)
                print(data)
                raise  # bare raise keeps the original traceback

        if "detail" in data and data["detail"] == "Already running a prediction":
            print(f"Prediction in progress, waited {total_wait_time}s, waiting more...")
            time.sleep(retry_interval)
            total_wait_time += retry_interval
            continue

        # Previously this path fell out of the loop and returned None, which
        # surfaced later as an unrelated error in the caller.
        raise RuntimeError(f"Unexpected response data: {data}")

    raise TimeoutError("Max wait time exceeded, unable to get valid response")


def image_equal_fuzzy(img_expected, img_actual, test_name="default", tol=20):
    """Return True when the mean absolute pixel difference is below ``tol``.

    Tol determined empirically - holding everything else equal but varying
    seed generates images that vary by at least 50.

    Images whose pixel arrays have different shapes are considered unequal.
    On a mismatch, the expected image, the actual image and (when it can be
    computed) their difference are written to ``/tmp/<test_name>`` for quick
    inspection.

    Args:
        img_expected: Reference image.
        img_actual: Image produced by the run under test.
        test_name: Name of the sub-directory of ``/tmp`` used for artifacts.
        tol: Upper bound (exclusive) on the mean absolute pixel delta.

    Returns:
        bool: Whether the images are considered equal.
    """
    # int32 avoids unsigned wrap-around when subtracting 8-bit pixel values.
    expected_px = np.array(img_expected, dtype=np.int32)
    actual_px = np.array(img_actual, dtype=np.int32)

    if expected_px.shape == actual_px.shape:
        mean_delta = np.mean(np.abs(expected_px - actual_px))
        imgs_equal = bool(mean_delta < tol)
    else:
        # Subtraction would raise or silently broadcast; neither is a match.
        imgs_equal = False

    if not imgs_equal:
        # save failures for quick inspection
        save_dir = f"/tmp/{test_name}"
        os.makedirs(save_dir, exist_ok=True)
        img_expected.save(os.path.join(save_dir, "expected.png"))
        img_actual.save(os.path.join(save_dir, "actual.png"))
        try:
            difference = ImageChops.difference(img_expected, img_actual)
        except ValueError as err:
            # Best effort: incompatible images should not mask the real result.
            print("Could not compute delta image:", err)
        else:
            difference.save(os.path.join(save_dir, "delta.png"))

    return imgs_equal


@pytest.fixture
def expected_image():
    """Open the reference output image stored under ``tests/assets``."""
    return Image.open("tests/assets/out.png")


def test_seeded_prediction(expected_image):
    """Run a fixed-seed prediction and compare it with the reference image."""
    data = {
        "image": "https://replicate.delivery/pbxt/KIIutO7jIleskKaWebhvurgBUlHR6M6KN7KHaMMWSt4OnVrF/musk_resize.jpeg",
        "prompt": "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality",
        "scheduler": "EulerDiscreteScheduler",
        "enable_lcm": False,
        "pose_image": "https://replicate.delivery/pbxt/KJmFdQRQVDXGDVdVXftLvFrrvgOPXXRXbzIVEyExPYYOFPyF/80048a6e6586759dbcb529e74a9042ca.jpeg",
        "sdxl_weights": "protovision-xl-high-fidel",
        "pose_strength": 0.4,
        "canny_strength": 0.3,
        "depth_strength": 0.5,
        "guidance_scale": 5,
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured",
        "ip_adapter_scale": 0.8,
        "lcm_guidance_scale": 1.5,
        "num_inference_steps": 30,
        "enable_pose_controlnet": True,
        "enhance_nonface_region": True,
        "enable_canny_controlnet": False,
        "enable_depth_controlnet": False,
        "lcm_num_inference_steps": 5,
        "controlnet_conditioning_scale": 0.8,
        "seed": 1337,
    }

    actual_image = local_run("http://localhost:5000/predictions", data)

    # Expected first, actual second, so the saved failure artifacts are
    # labelled correctly. The reference image comes from the fixture.
    test_result = image_equal_fuzzy(
        expected_image, actual_image, test_name="test_seeded_prediction"
    )

    if test_result:
        print("Test passed successfully.")
    else:
        print("Test failed.")

    assert test_result