|
import gradio as gr |
|
import os, sys |
|
from prompting import promptingutils |
|
from imageprocessing import imageprocessingtools |
|
from openai import OpenAI |
|
from prompting.promptingutils import DEFAULT_N_SAMPLES, DEFAULT_OBJECT_THRESHOLD, DEFAULT_RANDOM_STATE |
|
|
|
|
|
# Model identifiers offered in the UI dropdown; the first entry is preselected.
AVAILABLE_LLMS = [
    "vicuna-7b",
    "llama-7b-chat",
    "mistral-7b-instruct",
    "vicuna-13b",
]

# Default sampling temperature forwarded to the chat-completions request.
DEFAULT_TEMPERATURE = 0

# Endpoint of the OpenAI-compatible Llama API service.
_LLAMA_API_BASE_URL = "https://api.llama-api.com"

# Read at import time: a missing variable raises KeyError before the UI starts.
LLAMA_API_TOKEN = os.environ["LLAMA_API_TOKEN"]

# Shared OpenAI-SDK client pointed at the Llama API instead of api.openai.com.
client = OpenAI(base_url=_LLAMA_API_BASE_URL, api_key=LLAMA_API_TOKEN)
|
|
|
|
|
def caption_artwork(
    image_filepath: os.PathLike,
    llm: str,
    temperature=DEFAULT_TEMPERATURE,
    items_threshold=DEFAULT_OBJECT_THRESHOLD,
    random_state=DEFAULT_RANDOM_STATE,
    n_samples_per_emotion=DEFAULT_N_SAMPLES,
) -> tuple:
    """Analyse an artwork image and ask an LLM to comment on it.

    The image is passed to ``imageprocessingtools.extract_all_information_from_image``,
    whose result supplies an ``"emotion"``, a ``"colors_list"`` and
    ``"objects_and_probs"``. The objects are filtered with ``items_threshold``,
    everything is turned into a user prompt by ``promptingutils.get_user_prompt``,
    and the prompt is sent to ``llm`` via the module-level chat ``client``.

    Args:
        image_filepath: The image to analyse. The annotation says a path;
            presumably whatever ``extract_all_information_from_image`` accepts
            (TODO confirm). ``None`` is rejected (Gradio sends ``None`` when
            the image input is empty).
        llm: Model identifier sent as ``model`` to the chat-completions call.
        temperature: Sampling temperature for the completion.
        items_threshold: Threshold used to filter detected objects; also
            forwarded to the prompt builder as ``object_threshold``.
        random_state: Forwarded to the prompt builder.
        n_samples_per_emotion: Forwarded to the prompt builder.

    Returns:
        ``(emotion, colors, objects, commentary)`` where ``colors`` and
        ``objects`` are ", "-joined lists, ``emotion`` is passed through from
        the image analysis, and ``commentary`` is the model's reply text
        (``""`` if the API returned no content).

    Raises:
        gr.Error: If no image was provided.
    """
    if image_filepath is None:
        # Surface a readable message in the UI instead of an opaque failure
        # deep inside the image-processing code.
        raise gr.Error("Please provide an image of an artwork first.")

    all_information = imageprocessingtools.extract_all_information_from_image(image_filepath)

    emotion = all_information["emotion"]
    colors_list = all_information["colors_list"]
    objects_and_probs = all_information["objects_and_probs"]
    objects_list = promptingutils.filter_items(objects_and_probs, items_threshold=items_threshold)

    user_prompt = promptingutils.get_user_prompt(
        colors_list=colors_list,
        objects_list=objects_list,
        emotion=emotion,
        n_samples_per_emotion=n_samples_per_emotion,
        random_state=random_state,
        object_threshold=items_threshold,
    )

    response = client.chat.completions.create(
        model=llm,
        messages=[
            # NOTE(review): this system message claims an OpenAI-trained model
            # even though the selectable models are Llama/Mistral/Vicuna; kept
            # verbatim so generated commentary does not change.
            {"role": "system", "content": "Assistant is a large language model trained by OpenAI."},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
    )

    # ``content`` may be None (e.g. an empty/filtered reply); show "" instead.
    commentary_str = response.choices[0].message.content or ""
    colors_str = ", ".join(colors_list)
    objects_str = ", ".join(objects_list)
    emotion_str = emotion

    return (emotion_str, colors_str, objects_str, commentary_str)
|
|
|
# Single-column UI: artwork + model choice in, four analysis textboxes out.
with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML("""
<h2 style="text-align: center;">
LLMs talk about art!
</h2>
<p style="text-align: center;"></p>
""")

        # type="filepath" makes Gradio hand caption_artwork a file path, which
        # matches its ``image_filepath`` parameter; gr.Image otherwise
        # defaults to passing a numpy array.
        gr_image = gr.Image(
            label="An artwork: ",
            value="./1665_Girl_with_a_Pearl_Earring.jpg",
            type="filepath",
        )

        gr_model = gr.Dropdown(
            label="A Large Language Model",
            choices=AVAILABLE_LLMS,
            value=AVAILABLE_LLMS[0],
        )

        # Outputs, in the same order as caption_artwork's return tuple.
        gr_emotion = gr.Textbox(label="Evoked emotion: ")
        gr_colors = gr.Textbox(label="Main colors: ")
        gr_objects = gr.Textbox(label="Main objects present: ")
        gr_commentary = gr.Textbox(label="Commentary on the artwork:")

        btn = gr.Button(value="Submit your image!")
        btn.click(
            caption_artwork,
            inputs=[gr_image, gr_model],
            outputs=[gr_emotion, gr_colors, gr_objects, gr_commentary],
        )
|
|
|
""" |
|
def greet(name): |
|
return "Hello " + name + "!!" |
|
|
|
iface = gr.Interface(fn=greet, inputs="text", outputs="text") |
|
iface.launch() |
|
""" |
|
if __name__ == "__main__": |
|
demo.launch(allowed_paths = [os.path.dirname(__file__)]) |