|
import os |
|
import io |
|
import random |
|
|
|
import requests |
|
from PIL import Image |
|
from data_viber import AnnotatorInterFace |
|
|
|
# Hugging Face access token. Indexing os.environ directly means a missing
# token fails immediately at import time with KeyError.
HF_TOKEN = os.environ["HF_TOKEN"]

# Authorization header attached to every HTTP request made by this module.
HEADERS = {"Authorization": "Bearer " + HF_TOKEN}

DATASET_SERVER_URL = "https://datasets-server.huggingface.co"

# Not a bare dataset id: this is a pre-URL-encoded query fragment ("/" is
# written as %2F) that also selects the config and split. It is spliced
# verbatim after "dataset=" when request URLs are built.
DATASET_NAME = (
    "poloclub%2Fdiffusiondb"
    "&config=2m_random_1k"
    "&split=train"
)

# Hosted inference endpoint for the text-to-image model.
MODEL_URL = (
    "https://api-inference.huggingface.co/models/"
    "black-forest-labs/FLUX.1-schnell"
)
|
|
|
|
|
def retrieve_sample(idx, timeout=30):
    """Fetch a single row of the configured dataset from the datasets-server.

    Args:
        idx: Zero-based row offset to request.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        An ``(img_url, prompt)`` tuple taken from the row's ``image.src``
        and ``prompt`` fields.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        IndexError: If the server returns no row at ``idx``.
    """
    # DATASET_NAME already carries the config/split query parameters.
    api_url = (
        f"{DATASET_SERVER_URL}/rows"
        f"?dataset={DATASET_NAME}&offset={idx}&length=1"
    )
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Surface the HTTP status instead of a confusing KeyError on an error body.
    response.raise_for_status()

    rows = response.json()["rows"]
    if not rows:
        raise IndexError(f"no row at offset {idx} for dataset {DATASET_NAME}")

    row = rows[0]["row"]
    return row["image"]["src"], row["prompt"]
|
|
|
|
|
def get_rows(timeout=30):
    """Return the row count reported by the datasets-server ``/size`` endpoint.

    Args:
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The ``size.config.num_rows`` value from the response. This is
        presumably the size of the whole config rather than of a single
        split; confirm if the config ever has more than one split.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Surface the HTTP status instead of a confusing KeyError on an error body.
    response.raise_for_status()
    return response.json()["size"]["config"]["num_rows"]
|
|
|
|
|
def generate_response(prompt, timeout=120):
    """Generate an image for ``prompt`` using the hosted model at MODEL_URL.

    Args:
        prompt: Text prompt sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds to wait for the inference response before giving up.
            Image generation is slow, so the default is generous.

    Returns:
        A ``PIL.Image.Image`` decoded from the response body.

    Raises:
        requests.HTTPError: If the inference API answers with an error status.
        PIL.UnidentifiedImageError: If the body is not a decodable image.
    """
    payload = {"inputs": prompt}
    response = requests.post(
        MODEL_URL, headers=HEADERS, json=payload, timeout=timeout
    )
    # On failure the body is presumably an error message rather than image
    # bytes; report the HTTP error instead of letting PIL fail to identify it.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
|
|
|
|
|
def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next item for the annotator to display.

    The three arguments are accepted only to fit the callback signature
    expected by ``AnnotatorInterFace`` (this function is passed as ``fn``);
    they are unused here.

    Returns:
        ``(prompt, img_url, generated_image)``: a randomly chosen dataset
        prompt, the URL of that row's image, and an image freshly generated
        from the same prompt.
    """
    # randrange(n) yields a valid offset in 0..n-1. randint is inclusive on
    # both ends, so randint(0, n) - 1 could produce the invalid offset -1.
    random_idx = random.randrange(get_rows())
    img_url, prompt = retrieve_sample(random_idx)
    generated_image = generate_response(prompt)
    return prompt, img_url, generated_image
|
|
|
if __name__ == "__main__":
    # Build the image-generation preference UI around next_input (no target
    # dataset configured) and start it.
    AnnotatorInterFace.for_image_generation_preference(
        fn=next_input,
        dataset_name=None,
    ).launch()
|
|