"""Inference smoke tests for the ``sgm.inference.api`` sampling pipelines."""

from typing import Iterator, Tuple

import numpy
from PIL import Image
import pytest
from pytest import fixture
import torch

from sgm.inference.api import (
    model_specs,
    SamplingParams,
    SamplingPipeline,
    Sampler,
    ModelArchitecture,
)
import sgm.inference.helpers as helpers

# Shared inputs: every test samples the same prompt with the same step count.
_PROMPT = "A professional photograph of an astronaut riding a pig"
_STEPS = 10


@pytest.mark.inference
class TestInference:
    """Run each sampler through text-to-image, image-to-image and refiner calls."""

    @fixture(scope="class", params=model_specs.keys())
    def pipeline(self, request) -> Iterator[SamplingPipeline]:
        """Yield one class-scoped ``SamplingPipeline`` per key of ``model_specs``.

        After the dependent tests finish, the local reference is dropped and
        the CUDA cache is emptied before the next parametrization is built.
        """
        pipeline = SamplingPipeline(request.param)
        yield pipeline
        del pipeline
        torch.cuda.empty_cache()

    @fixture(
        scope="class",
        params=[
            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
    )
    def sdxl_pipelines(
        self, request
    ) -> Iterator[Tuple[SamplingPipeline, SamplingPipeline]]:
        """Yield a ``(base, refiner)`` pipeline pair for each SDXL version.

        Both local references are dropped and the CUDA cache is emptied once
        the dependent tests finish.
        """
        base_pipeline = SamplingPipeline(request.param[0])
        refiner_pipeline = SamplingPipeline(request.param[1])
        yield base_pipeline, refiner_pipeline
        del base_pipeline
        del refiner_pipeline
        torch.cuda.empty_cache()

    def create_init_image(self, h, w):
        """Build a random ``h`` x ``w`` RGB image and convert it for the pipeline.

        Returns whatever ``helpers.get_input_image_tensor`` produces for the
        generated PIL image.
        """
        # Uniform noise in [0, 255), truncated to 8-bit channels.
        image_array = numpy.random.rand(h, w, 3) * 255
        image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
        return helpers.get_input_image_tensor(image)

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Text-to-image returns a non-None result for every sampler."""
        output = pipeline.text_to_image(
            params=SamplingParams(sampler=sampler_enum.value, steps=_STEPS),
            prompt=_PROMPT,
            negative_prompt="",
            samples=1,
        )

        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Image-to-image returns a non-None result for every sampler."""
        output = pipeline.image_to_image(
            params=SamplingParams(sampler=sampler_enum.value, steps=_STEPS),
            image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
            prompt=_PROMPT,
            negative_prompt="",
            samples=1,
        )

        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    @pytest.mark.parametrize(
        "use_init_image", [True, False], ids=["img2img", "txt2img"]
    )
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
        sampler_enum,
        use_init_image,
    ):
        """Feed base-pipeline latents through the matching refiner pipeline.

        The base stage runs either image-to-image or text-to-image (selected
        by ``use_init_image``) with ``return_latents=True``; the second element
        of its output is then passed to the refiner as ``image``.
        """
        base_pipeline, refiner_pipeline = sdxl_pipelines
        if use_init_image:
            output = base_pipeline.image_to_image(
                params=SamplingParams(sampler=sampler_enum.value, steps=_STEPS),
                image=self.create_init_image(
                    base_pipeline.specs.height, base_pipeline.specs.width
                ),
                prompt=_PROMPT,
                negative_prompt="",
                samples=1,
                return_latents=True,
            )
        else:
            output = base_pipeline.text_to_image(
                params=SamplingParams(sampler=sampler_enum.value, steps=_STEPS),
                prompt=_PROMPT,
                negative_prompt="",
                samples=1,
                return_latents=True,
            )

        assert isinstance(output, (tuple, list))
        samples, samples_z = output
        assert samples is not None
        assert samples_z is not None

        refined = refiner_pipeline.refiner(
            params=SamplingParams(sampler=sampler_enum.value, steps=_STEPS),
            image=samples_z,
            prompt=_PROMPT,
            negative_prompt="",
            samples=1,
        )

        # Check the refiner stage the same way the other tests check theirs.
        # NOTE(review): assumes ``refiner`` returns its samples rather than
        # None - confirm against sgm.inference.api.
        assert refined is not None