# File size: 4,294 Bytes
# 14c8ffd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import base64
from io import BytesIO
import os
import time

import numpy as np
from PIL import Image, ImageChops
import pytest
import requests


def local_run(
    model_endpoint: str,
    model_input: dict,
    *,
    max_wait_time: int = 1000,
    retry_interval: int = 100,
):
    """Run a prediction on a local model server and return the first output image.

    POSTs ``{"input": model_input}`` to ``model_endpoint``. While the server
    answers with ``{"detail": "Already running a prediction"}``, sleeps
    ``retry_interval`` seconds and tries again, until the accumulated sleep
    time reaches ``max_wait_time``.

    Args:
        model_endpoint: URL of the predictions endpoint.
        model_input: Payload sent under the ``"input"`` key of the request body.
        max_wait_time: Upper bound, in seconds, on the total time spent
            sleeping between retries.
        retry_interval: Seconds to sleep between retries.

    Returns:
        The PIL image decoded from the base64 data URI in ``data["output"][0]``.

    Raises:
        RuntimeError: If the server returns a response that contains neither
            ``"output"`` nor the "already running" detail, or if
            ``max_wait_time`` is exhausted.
        Exception: Any error raised while decoding the output is re-raised
            after the input and the raw response are printed.
    """
    total_wait_time = 0
    while total_wait_time < max_wait_time:
        response = requests.post(model_endpoint, json={"input": model_input})
        data = response.json()

        if "output" in data:
            try:
                datauri = data["output"][0]
                # Data URIs have the form "data:<mediatype>;base64,<payload>";
                # split only on the first comma and keep the payload.
                base64_encoded_data = datauri.split(",", 1)[1]
                decoded_data = base64.b64decode(base64_encoded_data)
                return Image.open(BytesIO(decoded_data))
            except Exception:
                print("Error while processing output:")
                print("input:", model_input)
                print(data)
                raise

        if "detail" in data and data["detail"] == "Already running a prediction":
            print(f"Prediction in progress, waited {total_wait_time}s, waiting more...")
            time.sleep(retry_interval)
            total_wait_time += retry_interval
            continue

        # Fail loudly here: a plain `break` would skip the timeout error and
        # silently return None to the caller.
        print("Unexpected response data:", data)
        raise RuntimeError(f"Unexpected response data: {data}")

    raise RuntimeError("Max wait time exceeded, unable to get valid response")


def image_equal_fuzzy(img_expected, img_actual, test_name="default", tol=20):
    """Return whether two images are approximately equal.

    The images are converted to int32 arrays and compared by the mean absolute
    per-element difference; they count as equal when that mean is strictly
    below ``tol``. The tolerance was determined empirically: holding
    everything else equal but varying the seed generates images that differ
    by at least 50.

    On a mismatch, ``expected.png`` and ``actual.png`` are written to
    ``/tmp/<test_name>/`` for quick inspection. When the two images have the
    same array shape, their absolute difference is also written as
    ``delta.png``.

    Args:
        img_expected: Reference image (a PIL image on the failure path, since
            ``.save`` is called on it).
        img_actual: Image under test (same requirement as ``img_expected``).
        test_name: Sub-directory of /tmp where failure artifacts are saved.
        tol: Exclusive upper bound on the mean absolute pixel difference.

    Returns:
        True if the images are within tolerance, otherwise False. Images whose
        array shapes differ are reported as unequal.
    """
    # int32 avoids uint8 wrap-around when subtracting pixel values.
    img1 = np.array(img_expected, dtype=np.int32)
    img2 = np.array(img_actual, dtype=np.int32)

    # Images of different sizes or channel counts (e.g. RGB vs RGBA) cannot be
    # subtracted element-wise; report them as unequal instead of letting numpy
    # raise before any failure artifacts are saved.
    same_shape = img1.shape == img2.shape
    imgs_equal = same_shape and bool(np.mean(np.abs(img1 - img2)) < tol)

    if not imgs_equal:
        # save failures for quick inspection
        save_dir = f"/tmp/{test_name}"
        os.makedirs(save_dir, exist_ok=True)
        img_expected.save(os.path.join(save_dir, "expected.png"))
        img_actual.save(os.path.join(save_dir, "actual.png"))
        if same_shape:
            difference = ImageChops.difference(img_expected, img_actual)
            difference.save(os.path.join(save_dir, "delta.png"))

    return imgs_equal


@pytest.fixture
def expected_image():
    """Provide the reference output image from ``tests/assets/out.png``.

    The image is opened in a ``with`` block and yielded, so the underlying
    file handle stays open for the duration of the test and is closed at
    fixture teardown.
    """
    with Image.open("tests/assets/out.png") as img:
        yield img


def test_seeded_prediction(expected_image):
    """Seeded end-to-end prediction should reproduce the stored reference image.

    Sends a fixed input (seed 1337) to the model server at
    http://localhost:5000/predictions and checks that the returned image is
    within tolerance of ``expected_image`` (tests/assets/out.png). Requires
    the server to be running locally.
    """
    data = {
        "image": "https://replicate.delivery/pbxt/KIIutO7jIleskKaWebhvurgBUlHR6M6KN7KHaMMWSt4OnVrF/musk_resize.jpeg",
        "prompt": "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality",
        "scheduler": "EulerDiscreteScheduler",
        "enable_lcm": False,
        "pose_image": "https://replicate.delivery/pbxt/KJmFdQRQVDXGDVdVXftLvFrrvgOPXXRXbzIVEyExPYYOFPyF/80048a6e6586759dbcb529e74a9042ca.jpeg",
        "sdxl_weights": "protovision-xl-high-fidel",
        "pose_strength": 0.4,
        "canny_strength": 0.3,
        "depth_strength": 0.5,
        "guidance_scale": 5,
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured",
        "ip_adapter_scale": 0.8,
        "lcm_guidance_scale": 1.5,
        "num_inference_steps": 30,
        "enable_pose_controlnet": True,
        "enhance_nonface_region": True,
        "enable_canny_controlnet": False,
        "enable_depth_controlnet": False,
        "lcm_num_inference_steps": 5,
        "controlnet_conditioning_scale": 0.8,
        "seed": 1337,
    }

    actual_image = local_run("http://localhost:5000/predictions", data)
    # Use the fixture-provided reference (no second, unclosed Image.open) and
    # pass images in (expected, actual) order so failure artifacts are
    # labelled correctly.
    test_result = image_equal_fuzzy(
        expected_image, actual_image, test_name="test_seeded_prediction"
    )
    assert test_result, (
        "Generated image differs from tests/assets/out.png; "
        "see /tmp/test_seeded_prediction for expected/actual/delta images"
    )