# Spaces:
# Runtime error
import functools
import io
import os
import random

import requests
from PIL import Image
from dataset_viber import AnnotatorInterFace
# Hugging Face API token, read from the environment. If the variable is
# missing, this lookup raises KeyError at import time.
HF_TOKEN = os.environ["HF_TOKEN"]
# Authorization header attached to every request to the datasets server and
# to the Inference API.
HEADERS = {"Authorization": "Bearer " + HF_TOKEN}

DATASET_SERVER_URL = "https://datasets-server.huggingface.co"
# NOTE: this is not a bare dataset id. It already contains a URL-encoded "/"
# plus the `config` and `split` query parameters, and the request helpers
# splice it verbatim into their query strings.
DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"

MODEL_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
def retrieve_sample(idx, timeout=30):
    """Fetch a single dataset row from the Hugging Face datasets server.

    Args:
        idx: Zero-based row offset passed as the ``offset`` query parameter.
        timeout: Seconds to wait for the datasets server before giving up.

    Returns:
        An ``(img_url, prompt)`` tuple: the ``src`` of the row's image and
        the row's ``prompt`` text.

    Raises:
        requests.HTTPError: If the datasets server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
        IndexError: If the response contains no row for ``idx``.
    """
    api_url = f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}&offset={idx}&length=1"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Surface HTTP failures directly. Otherwise they show up later as a
    # confusing KeyError when the error payload has no "rows" key.
    response.raise_for_status()
    rows = response.json()["rows"]
    if not rows:
        raise IndexError(f"no dataset row at offset {idx}")
    row = rows[0]["row"]
    return row["image"]["src"], row["prompt"]
@functools.lru_cache(maxsize=1)
def get_rows():
    """Return the row count reported by the datasets server ``/size`` endpoint.

    The value read is ``size.config.num_rows`` for the configured dataset
    and config. It is cached for the lifetime of the process, because
    ``next_input`` asks for it on every sample and it is treated as static.
    Exceptions are not cached, so a failed request is retried on the next call.

    Returns:
        The number of rows, as an int.

    Raises:
        requests.HTTPError: If the datasets server answers with an error status.
        requests.Timeout: If the server does not respond within 30 seconds.
    """
    api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
    response = requests.get(api_url, headers=HEADERS, timeout=30)
    # Fail with the HTTP error rather than a KeyError on the error payload.
    response.raise_for_status()
    return response.json()["size"]["config"]["num_rows"]
def generate_response(prompt, timeout=120):
    """Generate an image for *prompt* via the hosted FLUX.1-schnell model.

    Args:
        prompt: Text prompt, sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds to wait for the Inference API. The default is
            generous because image generation may be slow.

    Returns:
        A ``PIL.Image.Image`` decoded from the raw response body.

    Raises:
        requests.HTTPError: If the Inference API returns an error status.
        requests.Timeout: If the API does not respond within ``timeout``.
    """
    response = requests.post(
        MODEL_URL, headers=HEADERS, json={"inputs": prompt}, timeout=timeout
    )
    # An error response body is presumably a JSON message, not image bytes.
    # Raise here instead of letting PIL fail with UnidentifiedImageError.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next item for the annotator.

    The three arguments are ignored. They are presumably required by the
    AnnotatorInterFace callback signature; confirm against dataset_viber.

    Picks a uniformly random dataset row and generates a new image from
    that row's prompt.

    Returns:
        ``(prompt, img_url, generated_image)``: the row's prompt, the URL of
        the row's image, and the freshly generated PIL image.

    Raises:
        ValueError: If the dataset reports zero rows.
    """
    # randrange(n) yields 0..n-1, exactly the valid offsets. randint is
    # inclusive at both ends, so randint(0, n) - 1 could produce -1.
    random_idx = random.randrange(get_rows())
    img_url, prompt = retrieve_sample(random_idx)
    generated_image = generate_response(prompt)
    return prompt, img_url, generated_image
if __name__ == "__main__":
    # Build the image-generation preference annotator around `next_input`
    # and serve it. NOTE(review): dataset_name=None presumably means the
    # annotations are not pushed to a Hub dataset; confirm in dataset_viber.
    AnnotatorInterFace.for_image_generation_preference(
        dataset_name=None,
        fn=next_input,
    ).launch()