import datetime
import io
import json
import os
import random
import uuid
import zipfile

import boto3
import gradio as gr
import numpy as np
import requests
from PIL import Image

# S3 bucket that archives every generated image together with its request payload.
BUCKET_NAME = 'dataset-novelai'

# NovelAI image-generation endpoint.
API_URL = "https://api.novelai.net/ai/generate-image"

# Seconds to wait for NovelAI before giving up, so a stalled request cannot
# hold one of the queue's worker slots indefinitely.
REQUEST_TIMEOUT = 120

# Maps each aspect-ratio choice offered in the UI to (width, height) in pixels.
RATIO_DIMENSIONS = {
    "Landscape (1216x832)": (1216, 832),
    "Square (1024x1024)": (1024, 1024),
    "Portrait (832x1216)": (832, 1216),
}

# Media types accepted as a zip-archive response.
ZIP_CONTENT_TYPES = {"application/x-zip-compressed", "application/zip"}

# Create an S3 client
s3 = boto3.client('s3')


def save_to_s3(image_data, payload, file_name):
    """Upload a generated image and its request payload to S3.

    Both objects are written under ``gradio/<YYYY-MM-DD>/`` in BUCKET_NAME,
    as ``<file_name>.webp`` and ``<file_name>.json``.

    Args:
        image_data: BytesIO containing the WEBP-encoded image. It is rewound
            before upload.
        payload: JSON text of the request payload.
        file_name: Base name (without extension) shared by both objects.
    """
    folder_name = datetime.datetime.now().strftime("%Y-%m-%d")
    image_key = f'gradio/{folder_name}/{file_name}.webp'
    payload_key = f'gradio/{folder_name}/{file_name}.json'

    # The buffer was just written to, so rewind it before uploading.
    image_data.seek(0)
    s3.upload_fileobj(image_data, BUCKET_NAME, image_key,
                      ExtraArgs={'ContentType': 'image/webp'})

    payload_data = io.BytesIO(payload.encode('utf-8'))
    s3.upload_fileobj(payload_data, BUCKET_NAME, payload_key,
                      ExtraArgs={'ContentType': 'application/json'})


def generate_novelai_image(input_text, quality_tags, seed, negative_prompt, scale, ratio, sampler):
    """Generate one image with NovelAI, archive it to S3, and return it.

    Args:
        input_text: Main prompt.
        quality_tags: Optional tags appended to the prompt after ", ".
        seed: Generation seed; -1 picks a random seed in [0, 2**32 - 1].
        negative_prompt: Text sent as the API's ``negative_prompt``.
        scale: Prompt-guidance scale sent to the API.
        ratio: One of the RATIO_DIMENSIONS keys, selecting width and height.
        sampler: Sampler name sent to the API.

    Returns:
        A tuple ``(image_array, payload_json)``: the decoded image as a numpy
        array, and the submitted request payload pretty-printed as JSON.

    Raises:
        gr.Error: If the API key is missing, the ratio is unknown, the request
            fails, or the response does not contain an image. Gradio shows
            the message to the user.
    """
    jwt_token = os.environ.get('NAI_API_KEY')
    if not jwt_token:
        raise gr.Error("NAI_API_KEY environment variable is not set.")

    try:
        width, height = RATIO_DIMENSIONS[ratio]
    except KeyError:
        raise gr.Error(f"Unsupported aspect ratio: {ratio!r}") from None

    # Check if quality tags are provided and append to input
    final_input = input_text
    if quality_tags:
        final_input += ", " + quality_tags

    # Sliders may deliver floats; the API should receive an integer seed.
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    headers = {
        "Authorization": f"Bearer {jwt_token}",
        "Content-Type": "application/json",
        "Origin": "https://novelai.net",
        "Referer": "https://novelai.net/",
    }

    payload = {
        "action": "generate",
        "input": final_input,
        "model": "nai-diffusion-3",
        "parameters": {
            "width": width,
            "height": height,
            "scale": scale,
            "sampler": sampler,
            "steps": 28,
            "n_samples": 1,
            "ucPreset": 0,
            "add_original_image": False,
            "cfg_rescale": 0,
            "controlnet_strength": 1,
            "dynamic_thresholding": False,
            "legacy": False,
            "negative_prompt": negative_prompt,
            "noise_schedule": "native",
            "qualityToggle": True,
            "seed": seed,
            "sm": False,
            "sm_dyn": False,
            "uncond_scale": 1,
        },
    }
    payload_json = json.dumps(payload, indent=4)

    try:
        response = requests.post(API_URL, json=payload, headers=headers,
                                 timeout=REQUEST_TIMEOUT)
    except requests.RequestException as err:
        raise gr.Error(f"Request to NovelAI failed: {err}") from err

    # Compare only the media type; ignore parameters such as "; charset=...".
    content_type = response.headers.get('Content-Type', '').split(';')[0].strip()
    if not response.ok or content_type not in ZIP_CONTENT_TYPES:
        raise gr.Error(
            f"The response is not a zip file (HTTP {response.status_code}): "
            f"{response.text[:300]}"
        )

    try:
        with zipfile.ZipFile(io.BytesIO(response.content), 'r') as zip_ref:
            file_names = zip_ref.namelist()
            if not file_names:
                raise gr.Error("No images found in the zip file.")
            with zip_ref.open(file_names[0]) as file:
                image = Image.open(file)
                # Image.open is lazy; decode now while the member is still open.
                image.load()
    except zipfile.BadZipFile as err:
        raise gr.Error(f"The response is not a valid zip file: {err}") from err

    # Re-encode as WEBP for archiving.
    buffered = io.BytesIO()
    image.save(buffered, format="WEBP", quality=98)
    # A random suffix keeps concurrent generations within the same second from
    # overwriting each other's S3 objects.
    file_name = f"{int(datetime.datetime.now().timestamp())}_{uuid.uuid4().hex[:8]}"
    save_to_s3(buffered, payload_json, file_name)

    return np.array(image), payload_json


# Create Gradio interface
iface = gr.Interface(
    fn=generate_novelai_image,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Textbox(label="Quality Tags", value="best quality, amazing quality, very aesthetic, absurdres"),
        gr.Slider(minimum=-1, maximum=2**32 - 1, step=1, value=-1, label="Seed"),
        gr.Textbox(label="Negative Prompt", value="nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"),
        gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Scale"),
        gr.Radio(choices=list(RATIO_DIMENSIONS), value="Portrait (832x1216)", label="Aspect Ratio"),
        gr.Dropdown(
            choices=[
                "k_euler",
                "k_euler_ancestral",
                "k_dpmpp_2s_ancestral",
                "k_dpmpp_2m",
                "k_dpmpp_sde",
                "ddim_v3",
            ],
            value="k_euler",
            label="Sampler",
        ),
    ],
    outputs=[
        "image",
        gr.Textbox(label="Submitted Payload"),
    ],
)

try:
    iface.queue(concurrency_count=3).launch(share=True)
except RuntimeError:
    # Share links are not available in some hosted environments (e.g. HF
    # Spaces); fall back to a plain launch.
    iface.queue(concurrency_count=3).launch()