Spaces:
Sleeping
Sleeping
import os | |
import base64 | |
import numpy as np | |
from PIL import Image | |
import io | |
import requests | |
import gradio as gr | |
import replicate | |
from dotenv import load_dotenv, find_dotenv | |
# Find a .env file (find_dotenv returns its path, or an empty string when none
# is found) and load its variables into the process environment.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)

# API token for Replicate, read from the environment after the .env load.
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
# Longest edge, in pixels, a starter image may have before it is downscaled.
_MAX_STARTER_EDGE = 512
# Number of images requested from the model and returned to the caller.
_NUM_OUTPUTS = 3
# Seconds to wait on each generated-image download before giving up.
_DOWNLOAD_TIMEOUT = 60


def _encode_starter_image(starter_image):
    """Return the starter image array as a base64-encoded JPEG string."""
    starter_image_pil = Image.fromarray(starter_image.astype('uint8'))
    width, height = starter_image_pil.size

    # Downscale so that neither edge exceeds the limit, keeping aspect ratio.
    if width > _MAX_STARTER_EDGE or height > _MAX_STARTER_EDGE:
        if width > height:
            new_width = _MAX_STARTER_EDGE
            # max(1, ...) guards against int() truncating a very thin image
            # down to a 0-pixel edge, which resize() rejects.
            new_height = max(1, int((_MAX_STARTER_EDGE / width) * height))
        else:
            new_height = _MAX_STARTER_EDGE
            new_width = max(1, int((_MAX_STARTER_EDGE / height) * width))
        starter_image_pil = starter_image_pil.resize(
            (new_width, new_height), Image.LANCZOS
        )

    # JPEG cannot store an alpha channel; arrays with 2 or 4 channels decode
    # to LA/RGBA and would make save() raise, so flatten those to RGB first.
    if starter_image_pil.mode not in ("L", "RGB"):
        starter_image_pil = starter_image_pil.convert("RGB")

    buffered = io.BytesIO()
    starter_image_pil.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def _download_image(image_url):
    """Fetch one generated image and return it as a PIL image."""
    print(image_url)
    response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT)
    print(response)
    # Fail with the HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def image_classifier(prompt, starter_image, image_strength):
    """Generate three images from a text prompt via the Replicate model.

    Despite its name, this function does not classify anything: it sends the
    prompt (with " in the style of TOK" appended) to the Replicate model and
    downloads the resulting images.

    Args:
        prompt: Text prompt describing the desired image.
        starter_image: Optional image as a numpy array, or None. When given,
            it is downscaled to at most 512 px per edge, JPEG-encoded and
            sent to the model as a base64 data URI.
        image_strength: Number whose complement (1 - image_strength) is sent
            as "prompt_strength". Only used when starter_image is given.

    Returns:
        A list of three PIL images.

    Raises:
        IndexError: If the model returns fewer than three results.
        requests.HTTPError: If downloading a generated image fails.
    """
    model_input = {
        "width": 512,
        "height": 512,
        "prompt": prompt + " in the style of TOK",
        # "refine": "expert_ensemble_refiner",  # currently disabled
        "apply_watermark": False,
        "num_inference_steps": 25,
        "num_outputs": _NUM_OUTPUTS,
        "lora_scale": .96,
    }
    if starter_image is not None:
        starter_image_base64 = _encode_starter_image(starter_image)
        model_input["image"] = "data:image/jpeg;base64," + starter_image_base64
        model_input["prompt_strength"] = 1 - image_strength

    output = replicate.run(
        # update to new trained model
        "ltejedor/cmf:3af83ef60d86efbf374edb788fa4183a6067416e2fadafe709350dc1efe37d1d",
        input=model_input
    )
    print(output)

    # The UI wires up exactly three outputs, so a short result is an error;
    # report it explicitly rather than as a bare index failure.
    if len(output) < _NUM_OUTPUTS:
        raise IndexError(
            f"expected {_NUM_OUTPUTS} images from the model, got {len(output)}"
        )

    return [_download_image(output[index]) for index in range(_NUM_OUTPUTS)]
# Gradio front end for image_classifier: a text box, an image input and an
# "Image Strength" slider feed the function; its three images are displayed.
demo = gr.Interface(
    fn=image_classifier,
    inputs=[
        "text",
        "image",
        gr.Slider(0, 1, step=0.025, value=0.2, label="Image Strength"),
    ],
    outputs=["image", "image", "image"],
)

# Serve the interface without creating a public share link.
demo.launch(share=False)