# ConZIC/app.py
import argparse
import os
import time

import gradio as gr
import nltk
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from clip import CLIP
from control_gen_utils import control_generate_caption
from gen_utils import generate_caption
from utils import create_logger, set_seed, format_output
def get_args():
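    """Parse command-line defaults; the Demo() callback below overrides most
    of these fields with values chosen in the Gradio UI."""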
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--batch_size", type=int, default=1, help = "Only supports batch_size=1 currently.")
parser.add_argument("--device", type=str,
default='cpu',choices=['cuda','cpu'])
## Generation and Controllable Type
    parser.add_argument('--run_type', default='caption', nargs='?',
                        choices=['caption', 'controllable'])
    parser.add_argument('--prompt', default='Image of a', type=str)
    parser.add_argument('--order', default='shuffle', nargs='?',
                        choices=['sequential', 'shuffle', 'span', 'random', 'parallel'],
                        help="Generation order of text")
    parser.add_argument('--control_type', default='sentiment', nargs='?',
                        choices=["sentiment", "pos"],
                        help="Which controllable task to conduct")
    # Note: type=list only makes sense for the default below; passing
    # --pos_type on the command line would split the string into characters.
    parser.add_argument('--pos_type', type=list,
                        default=[['DET'], ['ADJ', 'NOUN'], ['NOUN'],
                                 ['VERB'], ['VERB'], ['ADV'], ['ADP'],
                                 ['DET', 'NOUN'], ['NOUN'], ['NOUN', '.'],
                                 ['.', 'NOUN'], ['.', 'NOUN']],
                        help="Predefined part-of-speech template")
    parser.add_argument('--sentiment_type', default="positive", nargs='?',
                        choices=["positive", "negative"])
    parser.add_argument('--samples_num', default=2, type=int)
## Hyperparameters
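    # alpha, beta, and (for controllable runs) gamma weight the fluency,
    # image-matching, and control terms (see the help strings below); they are
    # presumably combined into a single candidate score at each Gibbs-sampling
    # step inside generate_caption / control_generate_caption.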
parser.add_argument("--sentence_len", type=int, default=10)
parser.add_argument("--candidate_k", type=int, default=200)
parser.add_argument("--alpha", type=float, default=0.02, help="weight for fluency")
parser.add_argument("--beta", type=float, default=2.0, help="weight for image-matching degree")
parser.add_argument("--gamma", type=float, default=5.0, help="weight for controllable degree")
parser.add_argument("--lm_temperature", type=float, default=0.1)
parser.add_argument("--num_iterations", type=int, default=1, help="predefined iterations for Gibbs Sampling")
## Models and Paths
parser.add_argument("--lm_model", type=str, default='bert-base-uncased',
help="Path to language model") # bert,roberta
parser.add_argument("--match_model", type=str, default='openai/clip-vit-base-patch32',
help="Path to Image-Text model") # clip,align
parser.add_argument("--caption_img_path", type=str, default='./examples/girl.jpg',
help="file path of the image for captioning")
parser.add_argument("--stop_words_path", type=str, default='stop_words.txt',
help="Path to stop_words.txt")
parser.add_argument("--add_extra_stopwords", type=list, default=[],
help="you can add some extra stop words")
args = parser.parse_args()
return args
def run_caption(args, image, lm_model, lm_tokenizer, clip, token_mask, logger):
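    """Generate args.samples_num captions for one image with ConZIC's
    sampling-based polishing. Per generate_caption's output convention,
    gen_texts[-2] appears to hold the last-iteration caption and
    gen_texts[-1] the best-scoring one."""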
    FinalCaptionList = []
    BestCaptionList = []
    image_instance = image.convert("RGB")
    for sample_id in range(args.samples_num):
        logger.info(f"Sample {sample_id + 1}: ")
        gen_texts, clip_scores = generate_caption(
            lm_model, clip, lm_tokenizer, image_instance, token_mask, logger,
            prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len,
            top_k=args.candidate_k, temperature=args.lm_temperature,
            max_iter=args.num_iterations, alpha=args.alpha, beta=args.beta,
            generate_order=args.order)
FinalCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-2]
BestCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-1]
FinalCaptionList.append(FinalCaptionStr)
BestCaptionList.append(BestCaptionStr)
return FinalCaptionList, BestCaptionList
def run_control(args, image, lm_model, lm_tokenizer, clip, token_mask, logger):
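    """Like run_caption, but generate controllable captions via
    control_generate_caption, adding gamma and the sentiment/POS control
    settings from args."""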
    FinalCaptionList = []
    BestCaptionList = []
    image_instance = image.convert("RGB")
    for sample_id in range(args.samples_num):
        logger.info(f"Sample {sample_id + 1}: ")
        gen_texts, clip_scores = control_generate_caption(
            lm_model, clip, lm_tokenizer, image_instance, token_mask, logger,
            prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len,
            top_k=args.candidate_k, temperature=args.lm_temperature,
            max_iter=args.num_iterations, alpha=args.alpha,
            beta=args.beta, gamma=args.gamma,
            ctl_type=args.control_type, style_type=args.sentiment_type,
            pos_type=args.pos_type, generate_order=args.order)
FinalCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-2]
BestCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-1]
FinalCaptionList.append(FinalCaptionStr)
BestCaptionList.append(BestCaptionStr)
return FinalCaptionList, BestCaptionList
def Demo(RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma, Img):
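    """Gradio callback: copy the UI choices into the argparse namespace, set
    up logging, load the masked LM and the CLIP matcher, build the stop-word
    token mask, and dispatch to run_caption or run_control."""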
args = get_args()
set_seed(args.seed)
args.num_iterations = NumIterations
args.sentence_len = Length
args.run_type = RunType
args.control_type = ControlType
args.sentiment_type = SentimentType
args.alpha = Alpha
args.beta = Beta
args.gamma = Gamma
args.samples_num = SamplesNum
args.order = Order
img = Img
run_type = "caption" if args.run_type=="caption" else args.control_type
if run_type=="sentiment":
run_type = args.sentiment_type
if os.path.exists("logger")== False:
os.mkdir("logger")
    logger = create_logger(
        "logger", 'demo_{}_{}_len{}_topk{}_alpha{}_beta{}_gamma{}_lmtemp{}_{}.log'.format(
            run_type, args.order, args.sentence_len,
            args.candidate_k, args.alpha, args.beta, args.gamma, args.lm_temperature,
            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())))
    logger.info(f"Generation order: {args.order}")
    logger.info(f"Run type: {run_type}")
logger.info(args)
    # Load the pre-trained masked LM and the CLIP matching model
lm_model = AutoModelForMaskedLM.from_pretrained(args.lm_model)
lm_tokenizer = AutoTokenizer.from_pretrained(args.lm_model)
lm_model.eval()
clip = CLIP(args.match_model)
clip.eval()
lm_model = lm_model.to(args.device)
clip = clip.to(args.device)
    ## Remove stop words by masking their token ids
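    # Build a 1 x vocab_size mask with zeros at stop-word token ids; the
    # generation utilities presumably use it to suppress these tokens when
    # proposing candidates. Words absent from the vocab map to the [UNK] id,
    # which therefore gets masked as well.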
    with open(args.stop_words_path, 'r', encoding='utf-8') as stop_words_file:
        stop_words = stop_words_file.readlines()
    stop_words_ = [stop_word.rstrip('\n') for stop_word in stop_words]
    stop_words_ += args.add_extra_stopwords
    stop_ids = lm_tokenizer.convert_tokens_to_ids(stop_words_)
    token_mask = torch.ones((1, lm_tokenizer.vocab_size))
    for stop_id in stop_ids:
        token_mask[0, stop_id] = 0
    token_mask = token_mask.to(args.device)
    if args.run_type == 'caption':
        FinalCaption, BestCaption = run_caption(args, img, lm_model, lm_tokenizer, clip, token_mask, logger)
    elif args.run_type == 'controllable':
        FinalCaption, BestCaption = run_control(args, img, lm_model, lm_tokenizer, clip, token_mask, logger)
    else:
        raise ValueError('run_type must be "caption" or "controllable"!')
    logger.handlers = []  # detach handlers so repeated UI runs don't duplicate log lines
FinalCaptionFormat, BestCaptionFormat = format_output(SamplesNum, FinalCaption, BestCaption)
return FinalCaptionFormat, BestCaptionFormat
def RunTypeChange(choice):
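    """Control widgets are only shown for the 'controllable' run type."""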
if choice == "caption":
return gr.update(visible=False)
elif choice == "controllable":
return gr.update(visible=True)
def ControlTypeChange(choice):
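    """The sentiment selector is only relevant for the 'sentiment' control type."""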
if choice == "pos":
return gr.update(visible=False)
elif choice == "sentiment":
return gr.update(visible=True)
with gr.Blocks() as demo:
gr.Markdown("""
# ConZIC
### Controllable Zero-shot Image Captioning by Sampling-Based Polishing
""")
with gr.Row():
with gr.Column():
RunType = gr.Radio(
["caption", "controllable"], value="caption", label="Run Type", info="Select the Run Type"
)
ControlType = gr.Radio(
["sentiment", "pos"], value="sentiment", label="Control Type", info="Select the Control Type",
visible=False, interactive=True
)
SentimentType = gr.Radio(
["positive", "negative"], value="positive", label="Sentiment Type", info="Select the Sentiment Type",
visible=False, interactive=True
)
Order = gr.Radio(
["sequential", "shuffle", "random"], value="shuffle", label="Order", info="Generation order of text"
)
RunType.change(fn = RunTypeChange, inputs = RunType, outputs = SentimentType)
RunType.change(fn = RunTypeChange, inputs = RunType, outputs = ControlType)
ControlType.change(fn = ControlTypeChange, inputs = ControlType, outputs = SentimentType)
with gr.Row():
Length = gr.Slider(
                    5, 15, value=10, label="Sentence Length", info="Choose between 5 and 15", step=1
)
NumIterations = gr.Slider(
                    1, 15, value=10, label="Num Iterations", info="Number of iterations for Gibbs sampling", step=1
)
with gr.Row():
SamplesNum = gr.Slider(
1, 5, value=2, label="Samples Num", step=1
)
Alpha = gr.Slider(
0, 1, value=0.02, label="Alpha", info="Weight for fluency", step=0.01
)
with gr.Row():
Beta = gr.Slider(
1, 5, value=2, label="Beta", info="Weight for image-matching degree", step=0.5
)
Gamma = gr.Slider(
                    1, 10, value=5, label="Gamma", info="Weight for controllable degree", step=0.5
)
with gr.Column():
Img = gr.Image(label="Upload Picture", type = "pil")
FinalCaption = gr.Textbox(label="Final Caption", lines=5, placeholder="Final Caption")
BestCaption = gr.Textbox(label="Best Caption", lines=5, placeholder="Best Caption")
with gr.Row():
gen_button = gr.Button("Submit")
clear_button = gr.Button("Reset")
gen_button.click(
fn = Demo,
inputs = [
RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma, Img
],
outputs = [
FinalCaption, BestCaption
]
)
clear_button.click(
        fn = lambda : [gr.Radio.update(value = 'caption'), gr.Radio.update(value = 'sentiment'), gr.Radio.update(value = 'positive'),
gr.Radio.update(value = 'shuffle'), gr.Slider.update(value = 10), gr.Slider.update(value = 10),
gr.Slider.update(value = 2), gr.Slider.update(value = 0.02), gr.Slider.update(value = 2),
gr.Slider.update(value = 5)
],
inputs = [
],
outputs = [
RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma
]
)
if __name__ == "__main__":
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('sentiwordnet')
demo.launch()