# Hugging Face Hub file view: uploaded by burtenshaw (HF staff),
# commit ce11ffc ("first commit"), 1.62 kB.
import os
import io
import random
import requests
from PIL import Image
from data_viber import AnnotatorInterFace
# Hugging Face access token. Indexing os.environ (rather than .get) means a
# missing HF_TOKEN raises KeyError as soon as this module is imported.
HF_TOKEN: str = os.environ["HF_TOKEN"]
# Authorization header attached to every datasets-server and model request.
HEADERS = {"Authorization": "Bearer " + HF_TOKEN}
DATASET_SERVER_URL: str = "https://datasets-server.huggingface.co"
# NOTE: despite its name this is a pre-encoded query fragment, not just a
# repo id: the URL-encoded repo ("/" written as "%2F") followed by the config
# and split parameters. It is spliced verbatim after "?dataset=" in URLs.
DATASET_NAME: str = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"
# Hosted Inference API endpoint for the FLUX.1-schnell image model.
MODEL_URL: str = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
def retrieve_sample(idx, *, timeout=60):
    """Fetch a single dataset row from the Hugging Face datasets-server.

    Args:
        idx: Zero-based row offset passed as ``offset`` to the ``/rows``
            endpoint (one row is requested via ``length=1``).
        timeout: Seconds to wait for the server before giving up.

    Returns:
        A ``(img_url, prompt)`` tuple taken from the row's ``image.src`` and
        ``prompt`` fields.

    Raises:
        requests.HTTPError: If the datasets-server responds with an error status.
        requests.Timeout: If the server does not answer within ``timeout``.
        IndexError: If the response's ``rows`` list is empty.
    """
    api_url = f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}&offset={idx}&length=1"
    # requests has no default timeout; without one a stalled connection
    # would block the annotation UI indefinitely.
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Surface the HTTP status instead of a confusing KeyError on an error body.
    response.raise_for_status()
    row = response.json()["rows"][0]["row"]
    return row["image"]["src"], row["prompt"]
def get_rows(*, timeout=60):
    """Return the row count reported by the datasets-server ``/size`` endpoint.

    The count is read from ``size.config.num_rows`` in the JSON response.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The number of rows as reported by the server.

    Raises:
        requests.HTTPError: If the datasets-server responds with an error status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
    # requests has no default timeout; bound the wait explicitly.
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Surface the HTTP status instead of a KeyError on an error body.
    response.raise_for_status()
    return response.json()["size"]["config"]["num_rows"]
def generate_response(prompt, *, timeout=300):
    """Generate an image for *prompt* via the hosted model at ``MODEL_URL``.

    Args:
        prompt: Text prompt sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds to wait for the inference endpoint. Generation is
            slower than a metadata lookup, hence the larger default.

    Returns:
        A ``PIL.Image.Image`` decoded from the raw response body.

    Raises:
        requests.HTTPError: If the endpoint responds with an error status.
        requests.Timeout: If no response arrives within ``timeout``.
    """
    payload = {"inputs": prompt}
    response = requests.post(MODEL_URL, headers=HEADERS, json=payload, timeout=timeout)
    # An error response does not carry image bytes. Check the status first so
    # callers see the real HTTP failure rather than PIL's
    # "cannot identify image file".
    response.raise_for_status()
    # Image.open decodes lazily; the BytesIO stays referenced by the image.
    return Image.open(io.BytesIO(response.content))
def next_input(_prompt, _completion_a, _completion_b):
    """Build the next item to annotate: a random dataset row plus a fresh generation.

    The three arguments are ignored (hence the leading underscores).
    Presumably the annotation interface passes in the previous item's values;
    TODO confirm against data_viber.

    Returns:
        ``(prompt, img_url, generated_image)``, where ``img_url`` is the
        dataset's original image and ``generated_image`` is a new image
        produced from the same prompt.

    Raises:
        ValueError: If the dataset reports zero rows (empty range).
    """
    # random.randint is inclusive at BOTH ends, so randint(0, n) - 1 could
    # produce -1. randrange(n) yields exactly the valid offsets 0..n-1.
    random_idx = random.randrange(get_rows())
    img_url, prompt = retrieve_sample(random_idx)
    generated_image = generate_response(prompt)
    return (prompt, img_url, generated_image)
if __name__ == "__main__":
    # Script entry point: build the image-generation preference UI around
    # next_input and start it immediately. dataset_name is passed as None,
    # exactly as before; what that means for persisting annotations is up to
    # data_viber (TODO confirm).
    AnnotatorInterFace.for_image_generation_preference(
        fn=next_input,
        dataset_name=None,
    ).launch()