File size: 3,544 Bytes
367eecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import base64
import numpy as np
from PIL import Image
import io
import requests
import gradio as gr

import replicate

from dotenv import load_dotenv, find_dotenv

# Find a .env file and load its variables into the process environment so
# that settings such as the Replicate API token are available via os.environ.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)

# NOTE(review): this value is not referenced elsewhere in the visible code;
# presumably the replicate client reads REPLICATE_API_TOKEN from the
# environment on its own — confirm before removing.
REPLICATE_API_TOKEN = os.environ.get('REPLICATE_API_TOKEN')


def _resize_to_fit(image, max_side=512):
    """Return *image* shrunk so neither side exceeds *max_side*, keeping aspect ratio.

    Images already within the limit are returned unchanged.
    """
    width, height = image.size
    if width <= max_side and height <= max_side:
        return image
    # Scale by the longer side; clamp to 1 px so very thin images never
    # produce a zero-sized dimension (which PIL rejects).
    if width > height:
        new_size = (max_side, max(1, int((max_side / width) * height)))
    else:
        new_size = (max(1, int((max_side / height) * width)), max_side)
    return image.resize(new_size, Image.LANCZOS)


def _encode_starter_image(starter_image, max_side=512):
    """Turn an image array into a base64 JPEG data URI suitable for Replicate.

    The image is first downscaled to fit within ``max_side`` x ``max_side``.
    Non-RGB images (e.g. RGBA or paletted) are converted to RGB because the
    JPEG encoder cannot store an alpha channel.
    """
    image = Image.fromarray(starter_image.astype('uint8'))
    image = _resize_to_fit(image, max_side)
    if image.mode != "RGB":
        image = image.convert("RGB")

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    encoded = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return "data:image/jpeg;base64," + encoded


def _download_image(url, timeout=60):
    """Fetch *url* and decode the response body as a PIL image.

    Raises:
        requests.HTTPError: on a non-2xx response, instead of letting PIL
            fail with a confusing error while decoding an error page.
        requests.Timeout: if the server does not respond within *timeout* s.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def image_classifier(prompt, starter_image, image_strength):
    """Generate three images from *prompt* with the hosted "TOK"-style model.

    Despite its name, this function does not classify anything: it is the
    Gradio callback that calls the Replicate model (text-to-image, or
    image-to-image when a starter image is supplied) and downloads the results.

    Args:
        prompt: Text prompt; " in the style of TOK" is appended to it.
        starter_image: Optional image array from Gradio, or ``None`` when no
            image was provided. It is downscaled to fit within 512x512 px
            before being uploaded.
        image_strength: Slider value in [0, 1]. It is sent to the model as
            ``prompt_strength = 1 - image_strength``, so higher values
            presumably keep more of the starter image. Ignored when there is
            no starter image.

    Returns:
        A list of three PIL images, one per Gradio image output.

    Raises:
        RuntimeError: if the model returns fewer than three image URLs.
    """
    num_outputs = 3

    model_input = {
        "width": 512,
        "height": 512,
        "prompt": prompt + " in the style of TOK",
        # "refine": "expert_ensemble_refiner",  # disabled option
        "apply_watermark": False,
        "num_inference_steps": 25,
        "num_outputs": num_outputs,
        "lora_scale": .96,
    }
    if starter_image is not None:
        # Encoding happens for every starter image, not only for ones that
        # needed resizing (small images previously caused a NameError).
        model_input["image"] = _encode_starter_image(starter_image)
        model_input["prompt_strength"] = 1 - image_strength

    output = replicate.run(
        # update to new trained model
        "ltejedor/cmf:3af83ef60d86efbf374edb788fa4183a6067416e2fadafe709350dc1efe37d1d",
        input=model_input,
    )
    print(output)

    # list() also handles clients/models that return an iterator of URLs.
    image_urls = list(output)
    if len(image_urls) < num_outputs:
        raise RuntimeError(
            f"Expected {num_outputs} images from Replicate, got {len(image_urls)}: {image_urls}"
        )

    return [_download_image(url) for url in image_urls[:num_outputs]]


# Gradio UI: a text prompt, an optional starter image and an "Image Strength"
# slider are passed to image_classifier. The three images it returns fill
# the three image output components.
demo = gr.Interface(
    fn=image_classifier,
    inputs=[
        "text",
        "image",
        gr.Slider(0, 1, step=0.025, value=0.2, label="Image Strength"),
    ],
    outputs=["image"] * 3,
)
demo.launch(share=False)