# Source: emoji-sdxl-old-diffusers / cog_sdxl / tests / test_remote_train.py
# Uploaded by multimodalart ("Upload 52 files", commit 18d0601, verified).
import time
import pytest
import replicate
@pytest.fixture(scope="module")
def model_name(request):
    """Name of the base Replicate model whose trainer is exercised.

    ``request`` is accepted but not used.
    """
    base_model = "stability-ai/sdxl"
    return base_model
@pytest.fixture(scope="module")
def model(model_name):
    """Fetch the Replicate model object named by ``model_name`` (network call)."""
    fetched_model = replicate.models.get(model_name)
    return fetched_model
@pytest.fixture(scope="module")
def version(model):
    """Return the first entry reported by ``model.versions.list()``.

    NOTE(review): presumably the API lists versions newest-first, which would
    make this the latest version — confirm against the Replicate client docs.
    """
    return model.versions.list()[0]
@pytest.fixture(scope="module")
def training(model_name, version):
    """Start a training job on the pinned base version and return it.

    The job is only created here, not awaited; consumers poll it themselves.
    Output goes to the ``replicate-internal/training-scratch`` destination.
    """
    dataset_url = "https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar"
    pinned_version = f"{model_name}:{version.id}"
    print(f"Training on {pinned_version}")
    return replicate.trainings.create(
        version=pinned_version,
        input={"input_images": dataset_url},
        destination="replicate-internal/training-scratch",
    )
@pytest.fixture(scope="module")
def prediction_tests():
    """Input payloads to run against the freshly trained version.

    NOTE(review): "TOK" is presumably the trigger token the trainer binds to
    the training subject — confirm against the trainer's defaults.
    """
    beach_scene = dict(
        prompt="A photo of TOK at the beach",
        refine="expert_ensemble_refiner",
    )
    return [beach_scene]
def test_training(training):
    """Wait for the training job to finish and assert that it succeeded.

    The job is polled once a minute. A fixed deadline bounds the wait, so a
    stuck or endlessly queued job fails the test instead of hanging the
    whole test run.
    """
    # NOTE(review): generous upper bound on training time; tune as needed.
    timeout_s = 2 * 60 * 60
    poll_interval_s = 60
    deadline = time.monotonic() + timeout_s  # monotonic: immune to wall-clock jumps

    while training.completed_at is None:
        if time.monotonic() >= deadline:
            pytest.fail(
                f"training {training.id} did not complete within {timeout_s}s "
                f"(last status: {training.status})"
            )
        time.sleep(poll_interval_s)
        training.reload()  # refresh status/completed_at from the API

    assert training.status == "succeeded", (
        f"training {training.id} finished with status {training.status}: "
        f"{training.error}"
    )
@pytest.fixture(scope="module")
def trained_model_and_version(training):
    """Return ``(model_name, version_id)`` of the version produced by ``training``.

    The ``training`` fixture returns as soon as the job is created, so this
    fixture waits for completion itself. Without that wait, running a
    dependent test on its own (e.g. ``pytest -k post_training``) would read
    ``training.output`` before it is populated.

    Fails the requesting test if the training does not finish within the
    deadline or does not succeed.
    """
    # NOTE(review): generous upper bound on training time; tune as needed.
    timeout_s = 2 * 60 * 60
    deadline = time.monotonic() + timeout_s
    while training.completed_at is None:
        if time.monotonic() >= deadline:
            pytest.fail(
                f"training {training.id} did not complete within {timeout_s}s"
            )
        time.sleep(60)
        training.reload()

    if training.status != "succeeded":
        pytest.fail(
            f"training {training.id} finished with status {training.status}: "
            f"{training.error}"
        )

    # output["version"] has the form "<model>:<version id>"; unpacking into
    # two names raises if the string is malformed.
    trained_model, trained_version = training.output["version"].split(":")
    return trained_model, trained_version
def test_post_training_predictions(trained_model_and_version, prediction_tests):
    """Run every payload against the trained version and assert each succeeds.

    All predictions are submitted before any is awaited (presumably so they
    can run concurrently on the service side), then each is awaited in order.
    """
    model_id, version_id = trained_model_and_version
    trained_version = replicate.models.get(model_id).versions.get(version_id)

    submitted = [
        replicate.predictions.create(version=trained_version, input=payload)
        for payload in prediction_tests
    ]

    for prediction in submitted:
        prediction.wait()
        assert prediction.status == "succeeded"