File size: 2,017 Bytes
2afcb7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
926ff6c
 
2afcb7e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import annotations

from argparse import ArgumentParser

import datasets
import gradio as gr
import numpy as np
import openai

from dataset_creation.generate_txt_dataset import generate


def main(openai_model: str):
    """Launch a Gradio demo that generates an edit instruction and edited caption.

    The demo offers three actions: clear all text boxes, fill the input box with
    a caption drawn from a shuffled copy of the dataset's ``TEXT`` column, and
    call ``generate`` on the current input caption.

    Args:
        openai_model: Model identifier passed unchanged as the first argument
            to ``generate``.
    """
    dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
    # Shuffle once at startup so "Random Input" can simply walk the list in order.
    captions = dataset[np.random.permutation(len(dataset))]["TEXT"]
    # Cursor into `captions`; it lives in this closure, so every click of the
    # random button advances the same counter.
    index = 0

    def click_random():
        """Return the next shuffled caption, wrapping around after the last one."""
        nonlocal index
        output = captions[index]
        index = (index + 1) % len(captions)
        return output

    def click_generate(caption: str):
        """Run ``generate`` on ``caption`` and return the values for the two output boxes.

        Raises:
            gr.Error: If the caption is empty or contains only whitespace.
        """
        # Reject blank input (including whitespace-only text) before calling
        # `generate`, so no request is made for a caption with no content.
        if not caption or not caption.strip():
            raise gr.Error("Input caption is missing!")
        # The caption is forwarded exactly as typed; only the emptiness check strips.
        edit_output = generate(openai_model, caption)
        if edit_output is None:
            # `generate` signalled failure; show a placeholder in both output boxes.
            return "Failed :(", "Failed :("
        # NOTE(review): presumably an (instruction, edited caption) pair matching
        # the two output components below; confirm against `generate`.
        return edit_output

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        txt_input = gr.Textbox(lines=3, label="Input Caption", interactive=True, placeholder="Type image caption here...")  # fmt: skip
        txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False)
        txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False)

        with gr.Row():
            clear_btn = gr.Button("Clear")
            random_btn = gr.Button("Random Input")
            generate_btn = gr.Button("Generate Instruction + Edited Caption")

            # "Clear" resets all three text boxes to empty strings.
            clear_btn.click(fn=lambda: ("", "", ""), inputs=[], outputs=[txt_input, txt_edit, txt_output])
            random_btn.click(fn=click_random, inputs=[], outputs=[txt_input])
            generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=[txt_edit, txt_output])

    # share=True asks Gradio to also create a public share link for the demo.
    demo.launch(share=True)


if __name__ == "__main__":
    # Script entry point: both command-line options are mandatory strings.
    cli = ArgumentParser()
    for flag in ("--openai-api-key", "--openai-model"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()

    # Set the API key on the openai module before the demo starts.
    openai.api_key = options.openai_api_key
    main(options.openai_model)