# diffusers-gallery-bot / classifier.py
# Author: radames — commit 42e31b9 ("classification code"), 2.03 kB
import os
import re
import requests
import json
import subprocess
from io import BytesIO
import uuid
from math import ceil
from tqdm import tqdm
from pathlib import Path
from db import Database
# Remote endpoints used by main(): the hosted classifier endpoint that receives
# image URLs, and the base URL that stored image file names are appended to.
CLASSIFIER_URL = "https://radames-aesthetic-style-nsfw-classifier.hf.space/run/inference"
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"

# Gallery database wrapper rooted at a folder path resolved relative to the
# current working directory.
DB_FOLDER = Path("diffusers-gallery-data")
database = Database(DB_FOLDER)
# Seconds to wait for the remote classifier before giving up on one image.
_CLASSIFIER_TIMEOUT = 120


def _fetch_model_rows():
    """Return every row of the ``models`` table as a list."""
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute("""
SELECT *
FROM models
""")
        return list(cursor.fetchall())


def _classify_image(image_url, timeout=_CLASSIFIER_TIMEOUT):
    """Send one image URL to the classifier and return ``{label: score}``.

    Scores are rounded to 3 decimals.

    Raises:
        requests.RequestException: on connection errors, timeouts or HTTP
            error statuses.
        ValueError: if the response body is not valid JSON.
        KeyError, IndexError, TypeError: if the JSON payload does not have
            the expected ``data[0][0]`` list of ``{"label", "score"}`` items.
    """
    response = requests.post(
        CLASSIFIER_URL,
        json={"data": [
            {"urls": [image_url]},  # json urls: list of images urls
            False,  # enable/disable gallery image output
            None,  # single image input
            None,  # files input
        ]},
        timeout=timeout,  # never block the batch job indefinitely
    )
    response.raise_for_status()
    payload = response.json()
    # data response is array data:[[{img0}, {img1}, {img2}...], Label, Gallery];
    # presumably each {imgN} is a list of {"label", "score"} entries — the
    # comprehension below relies on that shape.
    predictions = payload['data'][0][0]
    return {pred['label']: round(pred['score'], 3) for pred in predictions}


def _save_row_data(row_id, row_data):
    """Serialize ``row_data`` back into ``models.data`` for ``row_id`` and commit."""
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute("UPDATE models SET data = ? WHERE id = ?",
                       [json.dumps(row_data), row_id])
        db.commit()


def main():
    """Classify the first image of every model and store the scores in its row.

    For each row in ``models``, the JSON ``data`` column gets a ``class`` key
    that maps classifier labels to scores. Models without images get ``{}``.
    If the classifier request fails or returns an unexpected payload, the row
    is reported and left unchanged, so one bad request does not abort the run.
    Each row is committed as soon as it is processed, so progress survives an
    interruption.
    """
    results = _fetch_model_rows()
    for row in tqdm(results):
        row_id = row['id']
        # keep json data on row_data
        row_data = json.loads(row['data'])
        print("updating row", row_id)
        # filter nones; a missing or null "images" entry counts as no images
        images = [i for i in (row_data.get('images') or []) if i is not None]
        if images:
            # classifying only the first image
            image_url = ASSETS_URL + images[0]
            try:
                row_data['class'] = _classify_image(image_url)
            except (requests.RequestException, ValueError,
                    KeyError, IndexError, TypeError) as err:
                print("skipping row", row_id, "classification failed:", repr(err))
                continue
        else:
            row_data['class'] = {}
        _save_row_data(row_id, row_data)
if __name__ == "__main__":
    # Run the classification pass only when executed as a script, not on import.
    main()