Spaces: Running on Zero
import unittest | |
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible | |
class IsSafetensorsCompatibleTests(unittest.TestCase):
    """Unit tests for ``is_safetensors_compatible``.

    Every case builds a list of repository-relative weight filenames and
    asserts whether ``is_safetensors_compatible`` accepts that listing.

    * ``*_is_compatible`` cases list a ``.safetensors`` counterpart for every
      PyTorch ``.bin`` file. ``pytorch_model*.bin`` is paired with
      ``model*.safetensors`` (``safety_checker`` / ``text_encoder``), and
      ``diffusion_pytorch_model*.bin`` is paired with
      ``diffusion_pytorch_model*.safetensors`` (``vae`` / ``unet``).
    * ``*_is_not_compatible`` cases start from the complete four-component
      listing and drop exactly one ``.safetensors`` file (marked with a
      ``# Removed:`` comment), expecting ``False``.
    * ``*_variant`` cases repeat the above with ``fp16`` filenames and
      ``variant="fp16"``. ``*_variant_partial`` cases pass ``variant="fp16"``
      while listing only non-variant filenames, and still expect ``True``.
    """

    def test_all_is_compatible(self):
        """Every component ships both a ``.bin`` and a ``.safetensors`` file."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_diffusers_model_is_compatible(self):
        """A lone diffusers-style component with both formats is accepted."""
        filenames = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_diffusers_model_is_not_compatible(self):
        """Missing the unet ``.safetensors`` file makes the listing incompatible."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            # Removed: 'unet/diffusion_pytorch_model.safetensors',
        ]
        self.assertFalse(is_safetensors_compatible(filenames))

    def test_transformer_model_is_compatible(self):
        """``pytorch_model.bin`` is satisfied by a sibling ``model.safetensors``."""
        filenames = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_transformer_model_is_not_compatible(self):
        """Missing the text_encoder ``model.safetensors`` makes the listing incompatible."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            # Removed: 'text_encoder/model.safetensors',
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertFalse(is_safetensors_compatible(filenames))

    def test_all_is_compatible_variant(self):
        """Same as ``test_all_is_compatible`` but with ``fp16`` variant files."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_compatible_variant(self):
        """A lone diffusers-style component with both ``fp16`` formats is accepted."""
        filenames = [
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_compatible_variant_partial(self):
        """Passing a variant while only non-variant files exist is still accepted."""
        # pass variant but use the non-variant filenames
        filenames = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_not_compatible_variant(self):
        """Missing the unet ``fp16`` ``.safetensors`` file makes the listing incompatible."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            # Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
        ]
        variant = "fp16"
        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_compatible_variant(self):
        """``pytorch_model.fp16.bin`` is satisfied by ``model.fp16.safetensors``."""
        filenames = [
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_compatible_variant_partial(self):
        """Passing a variant while only non-variant files exist is still accepted."""
        # pass variant but use the non-variant filenames
        filenames = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_not_compatible_variant(self):
        """Missing the text_encoder ``fp16`` ``.safetensors`` makes the listing incompatible."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            # Removed: 'text_encoder/model.fp16.safetensors',
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))