Spaces:
Runtime error
Runtime error
import time | |
import pytest | |
import replicate | |
@pytest.fixture(scope="module")
def model_name(request):
    """Name of the base Replicate model that the training run starts from.

    ``request`` is pytest's built-in fixture; it is accepted but not used.
    Module scope is required because the module-scoped ``training`` fixture
    depends on this one, and pytest rejects a dependency on a narrower scope.
    """
    return "stability-ai/sdxl"
@pytest.fixture(scope="module")
def model(model_name):
    """Fetch the Replicate model object named by ``model_name``.

    Module-scoped so the lookup is done once per test module and can feed the
    module-scoped ``version``/``training`` fixtures.
    """
    return replicate.models.get(model_name)
@pytest.fixture(scope="module")
def version(model):
    """Return the first version listed for ``model``.

    NOTE(review): presumably the listing is ordered newest-first, so this is
    the latest published version -- confirm against the Replicate API.
    """
    versions = model.versions.list()
    return versions[0]
@pytest.fixture(scope="module")
def training(model_name, version):
    """Start one fine-tuning run of ``model_name`` at ``version`` and return it.

    Module scope means a single training is created and shared by every test
    in the module; without it each requesting test would start its own run.
    This fixture only starts the run -- it does not wait for it to finish.
    """
    # "owner/model:version_id" identifies the exact base version to train.
    model_version = f"{model_name}:{version.id}"
    training_input = {
        "input_images": "https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar"
    }
    print(f"Training on {model_version}")
    return replicate.trainings.create(
        version=model_version,
        input=training_input,
        destination="replicate-internal/training-scratch",
    )
@pytest.fixture(scope="module")
def prediction_tests():
    """Input payloads to run against the freshly trained model version.

    Each dict is passed verbatim as the ``input`` of one prediction.
    """
    return [
        {
            "prompt": "A photo of TOK at the beach",
            "refine": "expert_ensemble_refiner",
        },
    ]
def test_training(training):
    """Poll the shared training until it completes, then check it succeeded.

    The training object is refreshed once every 60 seconds while its
    ``completed_at`` is still unset.
    """
    while training.completed_at is None:
        time.sleep(60)
        training.reload()
    assert training.status == "succeeded", (
        f"training finished with status {training.status!r}"
    )
@pytest.fixture(scope="module")
def trained_model_and_version(training):
    """Return ``(trained_model, trained_version)`` produced by ``training``.

    The pair is parsed from the ``"model:version"`` string in the training's
    ``output["version"]``.

    The fixture waits for the training to complete itself, rather than relying
    on ``test_training`` having run first, so dependent tests also work when
    selected on their own. If ``test_training`` already finished the run, the
    wait loop is skipped.

    Raises:
        pytest.fail.Exception: if the training did not end as ``"succeeded"``,
            since there is then no trained version to read.
    """
    while training.completed_at is None:
        time.sleep(60)
        training.reload()
    if training.status != "succeeded":
        pytest.fail(
            f"training finished with status {training.status!r}; "
            "no trained version is available"
        )
    trained_model, trained_version = training.output["version"].split(":")
    return trained_model, trained_version
def test_post_training_predictions(trained_model_and_version, prediction_tests):
    """Run every payload in ``prediction_tests`` against the trained version.

    The test passes only if every prediction ends with status ``"succeeded"``.
    All predictions are waited on before asserting, so the failure report
    lists every failing input instead of stopping at the first one.
    """
    trained_model, trained_version = trained_model_and_version
    model = replicate.models.get(trained_model)
    version = model.versions.get(trained_version)

    # Create every prediction before waiting on any of them, so none is held
    # up behind another prediction's wait.
    predictions = [
        replicate.predictions.create(version=version, input=prediction_input)
        for prediction_input in prediction_tests
    ]

    for prediction in predictions:
        prediction.wait()

    failures = [
        (prediction_input, prediction.status)
        for prediction_input, prediction in zip(prediction_tests, predictions)
        if prediction.status != "succeeded"
    ]
    assert not failures, f"predictions did not succeed: {failures}"