import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from fromage import models
from fromage import utils
import gradio as gr
import huggingface_hub
import tempfile
class FromageChatBot:
    """Gradio chat front-end for the FROMAGe image/text model.

    The conversation is kept as a plain-text transcript of ``Q: ...\\nA:...``
    turns in ``self.chat_history``. An uploaded image is held in
    ``self.input_image`` until the next prompt consumes it.

    NOTE(review): chat_history / input_image live on the instance, so every
    browser session served by this process shares one conversation.
    """

    def __init__(self):
        # Download model from HF Hub.
        ckpt_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
        args_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
        self.model = models.load_fromage('./', args_path, ckpt_path)
        self.chat_history = ''
        self.input_image = None

    def reset(self):
        """Clear the transcript and any pending image.

        Returns:
            ([], []) -- empty values for the (gr.State, gr.Chatbot) outputs.
        """
        self.chat_history = ""
        self.input_image = None
        return [], []

    def upload_image(self, state, image_input):
        """Show an uploaded image in the chat and stage it for the next prompt.

        Args:
            state: list of (user, bot) message pairs backing the Chatbot.
            image_input: uploaded file object from gr.UploadButton; only its
                ``.name`` (path on disk) is used.

        Returns:
            (state, state) -- the same list for both gr.State and gr.Chatbot.
        """
        state += [(f"![](/file={image_input.name})", ":)")]
        # Convert before resizing so palette/1-bit images are resized in RGB
        # (PIL falls back to nearest-neighbour for "P"/"1" modes), and close
        # the underlying file once the pixels have been copied out.
        with Image.open(image_input.name) as img:
            self.input_image = img.convert('RGB').resize((224, 224))
        return state, state

    def save_image_to_local(self, image: 'Image.Image'):
        """Save ``image`` as a uniquely named PNG in the working directory.

        Args:
            image: an object with a PIL-style ``save(path)`` method.

        Returns:
            The file name relative to the working directory, usable in a
            Gradio ``/file=`` URL.
        """
        import os  # local import: only this method needs it.

        # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
        # mkstemp atomically creates a fresh file, unlike the private
        # tempfile._get_candidate_names(), which only yields a name that may
        # collide with an existing file.
        fd, path = tempfile.mkstemp(suffix='.png', dir='.')
        os.close(fd)
        try:
            image.save(path)
        except Exception:
            # Don't leave an empty placeholder file behind on failure.
            os.remove(path)
            raise
        # mkstemp may return an absolute path on newer Pythons; return just the
        # cwd-relative name so URLs look the same as before.
        return os.path.basename(path)

    def generate_for_prompt(self, input_text, state, ret_scale_factor, num_ims, num_words, temp):
        """Run one chat turn through the model and append the reply to ``state``.

        Args:
            input_text: the user's prompt.
            state: list of (user, bot) message pairs backing the Chatbot.
            ret_scale_factor: multiplier forwarded to the model as
                ``ret_scale_factor``.
            num_ims: max images to return; cast to int before use.
            num_words: max words to generate; cast to int before use.
            temp: sampling temperature forwarded as ``temperature``.

        Returns:
            (state, state) -- the updated list for gr.State and gr.Chatbot.
        """
        input_prompt = 'Q: ' + input_text + '\nA:'
        self.chat_history += input_prompt

        # If an image was uploaded, prepend it to the model.
        if self.input_image is not None:
            model_inputs = [self.input_image, self.chat_history]
        else:
            model_inputs = [self.chat_history]

        # gr.Number can deliver floats (e.g. 32.0); pass the counts as ints.
        model_outputs = self.model.generate_for_images_and_texts(
            model_inputs,
            max_num_rets=int(num_ims),
            num_words=int(num_words),
            ret_scale_factor=ret_scale_factor,
            temperature=temp,
        )

        response = ''
        text_outputs = []
        for output in model_outputs:
            if isinstance(output, str):
                text_outputs.append(output)
                response += output
            elif isinstance(output, list):
                # A list holds retrieved images; render each inline.
                for image in output:
                    filename = self.save_image_to_local(image)
                    response += f'![](/file={filename})'
            elif isinstance(output, Image.Image):
                filename = self.save_image_to_local(output)
                response += f'![](/file={filename})'

        # Only text is kept in the transcript fed back to the model.
        self.chat_history += ' '.join(text_outputs)
        if self.chat_history[-1] != '\n':
            self.chat_history += '\n'
        # An uploaded image applies to a single turn only.
        self.input_image = None

        state.append((input_text, response))
        return state, state

    def launch(self):
        """Build the Gradio Blocks UI, wire the callbacks, and start the server."""
        with gr.Blocks(css="#fromage-space {height:600px; overflow-y:auto;}") as demo:
            chatbot = gr.Chatbot(elem_id="fromage-space")
            gr_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.85):
                    text_input = gr.Textbox(show_label=False, placeholder="Upload an image [optional]. Then enter a text prompt, and press enter!").style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    image_btn = gr.UploadButton("Image", file_types=["image"])

            with gr.Row():
                with gr.Column(scale=0.20, min_width=0):
                    clear_btn = gr.Button("Clear")
                ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
                # precision=0: these are integer counts.
                max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=0, interactive=True, label="Max images to return")
                gr_max_len = gr.Number(value=32, precision=0, label="Max # of words returned", interactive=True)
                gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)

            text_input.submit(self.generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
            image_btn.upload(self.upload_image, [gr_state, image_btn], [gr_state, chatbot])
            clear_btn.click(self.reset, [], [gr_state, chatbot])

        demo.launch(share=False, server_name="0.0.0.0")
# Script entry: constructing the bot downloads the checkpoint from the HF Hub
# and loads the model (see FromageChatBot.__init__); launch() then builds the
# UI and starts the Gradio server on 0.0.0.0.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so merely
# importing this module also downloads the model and launches the server.
chatbot = FromageChatBot()
chatbot.launch()