File size: 4,374 Bytes
ffead1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gc
import unittest

from parameterized import parameterized

from diffusers import FlaxUNet2DConditionModel
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow


if is_flax_available():
    import jax
    import jax.numpy as jnp


@slow
@require_flax
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
        return image

    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        revision = "bf16" if fp16 else None

        model, params = FlaxUNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", dtype=dtype, revision=revision
        )
        return model, params

    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
        return hidden_states

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
            [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
            [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
            [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
            # fmt: on
        ]
    )
    def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
        model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        sample = model.apply(
            {"params": params},
            latents,
            jnp.array(timestep, dtype=jnp.int32),
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        assert sample.shape == latents.shape

        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware
        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
            # fmt: on
        ]
    )
    def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
        model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)

        sample = model.apply(
            {"params": params},
            latents,
            jnp.array(timestep, dtype=jnp.int32),
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        assert sample.shape == latents.shape

        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)