Spaces:
Running
Running
import datetime | |
import gradio as gr | |
import requests | |
import random | |
import io | |
import zipfile | |
from PIL import Image | |
import os | |
import numpy as np | |
import json | |
import boto3 | |
# Module-level S3 client; save_to_s3 uses it for every upload.
s3 = boto3.client(service_name='s3')
def save_to_s3(image_data, payload, file_name):
    """Upload an image buffer and its JSON payload side by side to S3.

    Both objects land under ``gradio/<YYYY-MM-DD>/`` (today's local date) in
    the ``dataset-novelai`` bucket and share ``file_name`` as their key stem:
    ``<file_name>.webp`` for the image and ``<file_name>.json`` for the payload.

    Args:
        image_data: Seekable binary buffer; rewound to offset 0 before upload
            and stored with ContentType ``image/webp``.
        payload: Text stored as UTF-8 with ContentType ``application/json``.
        file_name: Key stem without extension.
    """
    bucket = 'dataset-novelai'
    day = datetime.datetime.now().strftime("%Y-%m-%d")
    key_stem = f'gradio/{day}/{file_name}'

    def _put(fileobj, key, content_type):
        # Single upload with an explicit Content-Type so S3 serves it correctly.
        s3.upload_fileobj(fileobj, bucket, key, ExtraArgs={'ContentType': content_type})

    # The caller has just written into the buffer, so rewind before streaming it.
    image_data.seek(0)
    _put(image_data, f'{key_stem}.webp', 'image/webp')
    _put(io.BytesIO(payload.encode('utf-8')), f'{key_stem}.json', 'application/json')
# Output size for each aspect-ratio label offered by the UI.
_RATIO_DIMENSIONS = {
    "Landscape (1216x832)": (1216, 832),
    "Square (1024x1024)": (1024, 1024),
    "Portrait (832x1216)": (832, 1216),
}

# Upper bound (seconds) on the NovelAI request, so a stalled connection cannot
# block a queue worker forever.
_NAI_REQUEST_TIMEOUT = 120


# Function to handle the NovelAI API request
def generate_novelai_image(input_text, quality_tags, seed, negative_prompt, scale, ratio, sampler):
    """Generate one image with NovelAI (``nai-diffusion-3``) and archive it to S3.

    Args:
        input_text: Main prompt text.
        quality_tags: Optional tags appended to the prompt after ", ".
        seed: Sampling seed. Any negative value (e.g. the -1 sentinel) picks a
            random 32-bit seed. Coerced to ``int`` because Gradio sliders may
            deliver floats.
        negative_prompt: Text sent as the ``negative_prompt`` parameter.
        scale: Guidance scale sent as ``scale``.
        ratio: One of the keys of ``_RATIO_DIMENSIONS``.
        sampler: NovelAI sampler name.

    Returns:
        ``(image_array, payload_json)``: the decoded image as a NumPy array and
        the request payload pretty-printed as JSON (indent=4).

    Raises:
        gr.Error: if ``NAI_API_KEY`` is unset, ``ratio`` is unknown, the HTTP
            request fails, or the response does not contain an image. Gradio
            displays the message to the user instead of trying to render an
            error string as an image.
    """
    jwt_token = os.environ.get('NAI_API_KEY')
    if not jwt_token:
        # Without this check the request goes out as "Bearer None".
        raise gr.Error("NAI_API_KEY environment variable is not set.")

    try:
        width, height = _RATIO_DIMENSIONS[ratio]
    except KeyError:
        # Previously an unknown ratio left width/height unbound (UnboundLocalError).
        raise gr.Error(f"Unsupported ratio: {ratio!r}") from None

    # Join only non-empty parts so an empty prompt does not start with ", ".
    final_input = ", ".join(part for part in (input_text, quality_tags) if part)

    # Sliders can hand back floats; the API should receive an integer seed.
    seed = int(seed)
    if seed < 0:
        seed = random.randint(0, 2**32 - 1)

    url = "https://api.novelai.net/ai/generate-image"
    headers = {
        "Authorization": f"Bearer {jwt_token}",
        "Content-Type": "application/json",
        "Origin": "https://novelai.net",
        "Referer": "https://novelai.net/"
    }
    payload = {
        "action": "generate",
        "input": final_input,
        "model": "nai-diffusion-3",
        "parameters": {
            "width": width,
            "height": height,
            "scale": scale,
            "sampler": sampler,
            "steps": 28,
            "n_samples": 1,
            "ucPreset": 0,
            "add_original_image": False,
            "cfg_rescale": 0,
            "controlnet_strength": 1,
            "dynamic_thresholding": False,
            "legacy": False,
            "negative_prompt": negative_prompt,
            "noise_schedule": "native",
            "qualityToggle": True,
            "seed": seed,
            "sm": False,
            "sm_dyn": False,
            "uncond_scale": 1,
        }
    }
    payload_json = json.dumps(payload, indent=4)

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=_NAI_REQUEST_TIMEOUT)
    except requests.RequestException as exc:
        raise gr.Error(f"Request to NovelAI failed: {exc}") from exc

    if response.headers.get('Content-Type') != 'application/x-zip-compressed':
        # Include status and a snippet of the body (typically a JSON error) to aid debugging.
        raise gr.Error(
            f"The response is not a zip file. (HTTP {response.status_code}: {response.text[:300]})"
        )

    try:
        with zipfile.ZipFile(io.BytesIO(response.content), 'r') as zip_ref:
            file_names = zip_ref.namelist()
            if not file_names:
                raise gr.Error("No images found in the zip file.")
            with zip_ref.open(file_names[0]) as file:
                image = Image.open(file)
                # PIL decodes lazily; force it while the zip member is still open.
                image.load()
    except zipfile.BadZipFile as exc:
        raise gr.Error(f"The response is not a valid zip file: {exc}") from exc

    buffered = io.BytesIO()
    image.save(buffered, format="WEBP", quality=98)
    # Second-resolution timestamps alone collide when concurrent requests finish
    # in the same second, overwriting each other's S3 objects; add a random suffix.
    timestamp = int(datetime.datetime.now().timestamp())
    file_name = f"{timestamp}_{random.getrandbits(32):08x}"
    save_to_s3(buffered, payload_json, file_name)
    return np.array(image), payload_json
# Defaults and choices for the form, hoisted out of the component constructors.
_DEFAULT_QUALITY_TAGS = "best quality, amazing quality, very aesthetic, absurdres"
_DEFAULT_NEGATIVE_PROMPT = (
    "nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, "
    "bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, "
    "extra digits, artistic error, username, scan, [abstract]"
)
_RATIO_CHOICES = ["Landscape (1216x832)", "Square (1024x1024)", "Portrait (832x1216)"]
_SAMPLER_CHOICES = [
    "k_euler", "k_euler_ancestral", "k_dpmpp_2s_ancestral",
    "k_dpmpp_2m", "k_dpmpp_sde", "ddim_v3",
]

# Create Gradio interface: inputs are listed in the same order as the
# parameters of generate_novelai_image.
iface = gr.Interface(
    fn=generate_novelai_image,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Textbox(label="Quality Tags", value=_DEFAULT_QUALITY_TAGS),
        gr.Slider(minimum=-1, maximum=2**32 - 1, step=1, value=-1, label="Seed"),
        gr.Textbox(label="Negative Prompt", value=_DEFAULT_NEGATIVE_PROMPT),
        gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Scale"),
        gr.Radio(choices=_RATIO_CHOICES, value=_RATIO_CHOICES[-1]),
        gr.Dropdown(choices=_SAMPLER_CHOICES, value=_SAMPLER_CHOICES[0], label="Sampler"),
    ],
    outputs=["image", gr.Textbox(label="Submitted Payload")],
)


def _launch(**launch_kwargs):
    """Enable the request queue (3 concurrent workers) and start the app."""
    iface.queue(concurrency_count=3).launch(**launch_kwargs)


try:
    _launch(share=True)
except RuntimeError:  # use in HF spaces
    _launch()