burtenshaw HF staff commited on
Commit
ce11ffc
·
1 Parent(s): 4c07fb9

first commit

Browse files
Files changed (3) hide show
  1. .vscode/launch.json +17 -0
  2. app.py +54 -0
  3. requirements.txt +19 -0
.vscode/launch.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Python Debugger: Current File",
9
+ "type": "debugpy",
10
+ "request": "launch",
11
+ "program": "${file}",
12
+ "console": "integratedTerminal",
13
+ "envFile": "${workspaceFolder}/.env",
14
+ "python": "${workspaceFolder}/../data-viber/.venv/bin/python"
15
+ }
16
+ ]
17
+ }
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import random
4
+
5
+ import requests
6
+ from PIL import Image
7
+ from data_viber import AnnotatorInterFace
8
+
9
+ HF_TOKEN = os.environ["HF_TOKEN"]
10
+ HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
11
+ DATASET_SERVER_URL = "https://datasets-server.huggingface.co"
12
+ DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"
13
+ MODEL_URL = (
14
+ "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
15
+ )
16
+
17
+
18
+ def retrieve_sample(idx):
19
+ api_url = f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}&offset={idx}&length=1"
20
+ response = requests.get(api_url, headers=HEADERS)
21
+ data = response.json()
22
+ img_url = data["rows"][0]["row"]["image"]["src"]
23
+ prompt = data["rows"][0]["row"]["prompt"]
24
+ return img_url, prompt
25
+
26
+
27
+ def get_rows():
28
+ api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
29
+ response = requests.get(api_url, headers=HEADERS)
30
+ num_rows = response.json()["size"]["config"]["num_rows"]
31
+ return num_rows
32
+
33
+
34
+ def generate_response(prompt):
35
+ payload = {
36
+ "inputs": prompt,
37
+ }
38
+ response = requests.post(MODEL_URL, headers=HEADERS, json=payload)
39
+ image = Image.open(io.BytesIO(response.content))
40
+ return image
41
+
42
+
43
+ def next_input(_prompt, _completion_a, _completion_b):
44
+ random_idx = random.randint(0, get_rows()) - 1
45
+ img_url, prompt = retrieve_sample(random_idx)
46
+ generated_image = generate_response(prompt)
47
+ return (prompt, img_url, generated_image)
48
+
49
+ if __name__ == "__main__":
50
+ interface = AnnotatorInterFace.for_image_generation_preference(
51
+ fn=next_input,
52
+ dataset_name=None,
53
+ )
54
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio[oauth]>=4.3,<5
2
+ datasets>=2,<3
3
+ sentence-transformers>=3
4
+ optimum[onnxruntime]>=1.21.3
5
+ tabulate>=0.9.0
6
+ diffusers>=0.30.0
7
+ transformers>=4.43.4
8
+ ipykernel>=6.29.5
9
+ umap-learn>=0.5,<1
10
+ plotly>=5,<6
11
+ dash>=2.11,<3
12
+ dash-bootstrap-components>=1.6.0
13
+ pre-commit>=3.8.0
14
+ ruff>=0.5,<1
15
+ pytest>=8,<9
16
+ black>=24,<25
17
+ openpyxl>=3,<4
18
+ requests
19
+ pillow