import gc
import unittest

from parameterized import parameterized

from diffusers import FlaxUNet2DConditionModel
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow


if is_flax_available():
    import jax
    import jax.numpy as jnp


@slow
@require_flax
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
        # Pre-generated Gaussian-noise latents stored as .npy files on the Hugging Face Hub.
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
        return image

    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
        # The "bf16" revision of the checkpoint holds bfloat16 Flax weights; use float32 otherwise.
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        revision = "bf16" if fp16 else None

        model, params = FlaxUNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", dtype=dtype, revision=revision
        )
        return model, params

    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
        # Pre-computed text-encoder hidden states of shape (batch, sequence length, embedding dim).
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
        return hidden_states

    def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
        # Compare the Flax bfloat16 UNet output for Stable Diffusion v1-4 against reference
        # slices produced by the PyTorch float16 model.
        model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        sample = model.apply(
            {"params": params},
            latents,
            jnp.array(timestep, dtype=jnp.int32),
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        assert sample.shape == latents.shape

        output_slice = jnp.asarray(jax.device_get(sample[-1, -2:, -2:, :2].flatten()), dtype=jnp.float32)
        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)

    def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
        # Same comparison for Stable Diffusion 2, which uses 96x96 latents and
        # 1024-dimensional text-encoder hidden states.
        model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)

        sample = model.apply(
            {"params": params},
            latents,
            jnp.array(timestep, dtype=jnp.int32),
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        assert sample.shape == latents.shape

        output_slice = jnp.asarray(jax.device_get(sample[-1, -2:, -2:, :2].flatten()), dtype=jnp.float32)
        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
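

# A minimal sketch of how these slow integration tests are typically invoked. The file path
# below is an assumption, and RUN_SLOW is the environment flag checked by the `slow` decorator
# from diffusers.utils.testing_utils:
#
#   RUN_SLOW=1 python -m pytest -s -v tests/models/test_models_unet_2d_flax.py
#
# Optional convenience guard so the module can also be executed directly with `python`:
if __name__ == "__main__":
    unittest.main()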