# Spaces:
# Runtime error
# Runtime error
import contextlib
import hmac
import json
import os
import shutil
import sqlite3
import subprocess
from pathlib import Path

from flask import Flask, request, jsonify, g
from flask_apscheduler import APScheduler
from flask_cors import CORS
from flask_expects_json import expects_json
from huggingface_hub import Repository
from jsonschema import ValidationError
from PIL import Image
# Deployment mode is read from FLASK_ENV; an unset variable counts as
# 'production', and only the exact value 'development' enables dev mode.
MODE = os.environ.get('FLASK_ENV', 'production')
IS_DEV = (MODE == 'development')
# The Flask application; static assets are exposed under the /static URL prefix.
app = Flask(__name__, static_url_path='/static')
# Ask jsonify for compact (non-indented) output.
# NOTE(review): newer Flask releases replaced this key with app.json.compact —
# confirm against the pinned Flask version.
app.config.update(JSONIFY_PRETTYPRINT_REGULAR=False)
# JSON Schema for a palette submission: exactly a "prompt" string and an
# "images" list, where every image is exactly a 5-entry "colors" list of
# strings plus an "imgURL" string.
schema = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "images": {
            "type": "array",
            "items": {
                "type": "object",
                "minProperties": 2,
                "maxProperties": 2,
                # min/maxProperties only bound the number of keys; without
                # "required", any two arbitrary keys would validate.
                "required": ["colors", "imgURL"],
                "properties": {
                    "colors": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "maxItems": 5,
                        "minItems": 5,
                    },
                    "imgURL": {"type": "string"},
                },
            },
        },
    },
    # Same reasoning as above: pin the two expected top-level keys.
    "required": ["prompt", "images"],
    "minProperties": 2,
    "maxProperties": 2,
}
# Enable cross-origin requests on the app using flask_cors defaults.
CORS(app)
# Local working copy of the palettes database (relative to the CWD).
DB_FILE = Path("./data.db")
# Hub token: passed to the dataset Repository and checked by push().
TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN')
# Check out the Hub dataset repo into ./data and bring it up to date.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/color-palettes-sd",
    use_auth_token=TOKEN,
)
repo.git_pull()
# Seed the local working DB from the dataset checkout (repo -> local; the
# reverse direction of update_repository()). If the checkout has no data.db
# yet, start from an empty local database instead of failing at import time.
if Path("./data/data.db").exists():
    shutil.copyfile("./data/data.db", DB_FILE)
# Make sure the palettes table exists, then close this start-up connection;
# request handlers open their own through get_db(). closing() is required
# because sqlite3's own context manager commits but never closes.
with contextlib.closing(sqlite3.connect(DB_FILE)) as db:
    db.execute(
        'CREATE TABLE IF NOT EXISTS palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)')
    db.commit()
    data = db.execute("SELECT * FROM palettes").fetchall()
    if IS_DEV:
        print(f"Loaded {len(data)} palettes from local db")
def get_db():
    """Return the SQLite connection cached on ``flask.g``, opening it if needed.

    The first call opens ``DB_FILE`` and stores the connection as
    ``g._database``; later calls that see the same ``g`` reuse it. Rows are
    returned as ``sqlite3.Row``, so columns can be read by name.
    """
    conn = getattr(g, '_database', None)
    if conn is not None:
        return conn
    conn = sqlite3.connect(DB_FILE)
    conn.row_factory = sqlite3.Row
    g._database = conn
    return conn
@app.teardown_appcontext
def close_connection(exception):
    """Close the connection that get_db() cached on ``flask.g``, if any.

    Registered as an app-context teardown callback so every connection opened
    by get_db() is released when its context ends. Without the registration
    this function is never called.

    Args:
        exception: the unhandled error (or None) that Flask passes to teardown
            callbacks; not used here.
    """
    # pop() also removes the attribute, so the closed connection can never be
    # handed out again from this ``g``.
    db = g.pop('_database', None)
    if db is not None:
        db.close()
def update_repository():
    """Publish the local palettes database to the Hub dataset repo.

    Pulls the ``./data`` checkout, copies the working DB (``DB_FILE``) over
    ``./data/data.db``, exports every row to ``./data/data.json``, then starts
    the git amend/force-push and ``repo.push_to_hub``. The scheduler in the
    ``__main__`` block also runs it outside any request, so it opens its own
    connection instead of using get_db().
    """
    repo.git_pull()
    # Copy the local working DB *into* the dataset checkout (local -> repo).
    # NOTE(review): this is a raw file copy of a database the app may be
    # writing to concurrently; sqlite3.Connection.backup would give a
    # consistent snapshot — consider switching.
    shutil.copyfile(DB_FILE, "./data/data.db")

    # sqlite3's context manager only commits/rolls back; closing() is what
    # actually releases the connection.
    with contextlib.closing(sqlite3.connect("./data/data.db")) as db:
        db.row_factory = sqlite3.Row
        palettes = db.execute("SELECT * FROM palettes").fetchall()
        data = [{'id': row['id'], 'data': json.loads(row['data']),
                 'created_at': row['created_at']} for row in palettes]

    # Compact separators keep the exported JSON small.
    with open('./data/data.json', 'w', encoding='utf-8') as f:
        json.dump(data, f, separators=(',', ':'))

    print("Updating repository")
    # Amending and force-pushing rewrites the last commit instead of adding a
    # new one on every sync.
    # NOTE(review): Popen does not wait, so this git process may run at the
    # same time as push_to_hub below in the same checkout — confirm the two
    # cannot collide (e.g. on .git/index.lock).
    subprocess.Popen(
        "git add . && git commit --amend -m 'update' && git push --force",
        cwd="./data", shell=True)
    repo.push_to_hub(blocking=False)
def index():
    """Serve ``index.html`` from the application's static folder."""
    # NOTE(review): no @app.route decorator is visible for this view in this
    # copy of the file — confirm how it is registered.
    page_name = 'index.html'
    return app.send_static_file(page_name)
def push():
    """Run update_repository() immediately if the caller presents the Hub token.

    The caller's token is read from the ``token`` request header and compared
    with ``TOKEN`` (the HUGGING_FACE_HUB_TOKEN environment variable).

    Returns:
        ``{"success": true}`` after the sync has run, or ``("Error", 401)``
        when the header is missing or empty, no server token is configured,
        or the tokens differ.
    """
    # NOTE(review): no @app.route decorator is visible for this view in this
    # copy of the file — confirm how it is exposed.
    provided = request.headers.get('token')
    # A missing header used to raise KeyError (answered as 400); an unset or
    # empty server TOKEN must never authorize anything.
    if not provided or not TOKEN:
        return "Error", 401
    # Constant-time comparison, so response timing does not reveal how much of
    # the token matched. Comparing bytes avoids TypeError on non-ASCII input.
    if not hmac.compare_digest(provided.encode('utf-8'), TOKEN.encode('utf-8')):
        return "Error", 401
    update_repository()
    return jsonify({'success': True})
def getAllData():
    """Return every stored palette as a list of plain dicts.

    Each dict carries ``id``, ``data`` (the JSON text column decoded back into
    Python objects) and ``created_at``, read through get_db()'s connection.
    """
    rows = get_db().execute("SELECT * FROM palettes").fetchall()
    return [
        {
            'id': row['id'],
            'data': json.loads(row['data']),
            'created_at': row['created_at'],
        }
        for row in rows
    ]
def getdata():
    """Respond with all stored palettes as JSON (item shape: see getAllData)."""
    # NOTE(review): no route decorator is visible for this view — confirm the URL.
    all_palettes = getAllData()
    return jsonify(all_palettes)
@expects_json(schema)
def create():
    """Store one palette submission and return the full, updated palette list.

    ``expects_json`` validates the request body against ``schema`` and puts
    the parsed JSON on ``g.data``. Without the decorator, ``g.data`` is never
    set. The payload is stored as JSON text in the ``data`` column.
    """
    # NOTE(review): no @app.route decorator is visible here; this view is
    # presumably meant to accept POST — confirm the URL rule.
    db = get_db()
    db.execute("INSERT INTO palettes(data) VALUES (?)", [json.dumps(g.data)])
    db.commit()
    return jsonify(getAllData())
@app.errorhandler(400)
def bad_request(error):
    """Render JSON-schema validation failures as a compact JSON 400 response.

    flask_expects_json rejects invalid bodies with a 400 whose ``description``
    is the jsonschema ``ValidationError``. Those are answered with
    ``{"error": <message>}``; any other 400 is returned unchanged. The handler
    must be registered for 400, otherwise it is never invoked.
    """
    original_error = error.description
    if isinstance(original_error, ValidationError):
        return jsonify({'error': original_error.message}), 400
    return error
if __name__ == '__main__':
    # The hourly dataset sync is only scheduled outside development mode.
    if IS_DEV:
        print("Not Starting scheduler -- Running Development")
    else:
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        scheduler.add_job(id='Update Dataset Repository',
                          func=update_repository, trigger='interval', hours=1)
        scheduler.start()
    # The interactive Werkzeug debugger must never be enabled on a public
    # host (0.0.0.0) in production, so tie it to IS_DEV like the reloader.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)),
            debug=IS_DEV, use_reloader=IS_DEV)