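"""Gradio demo for ConZIC: Controllable Zero-shot Image Captioning by Sampling-Based Polishing.

The app loads a masked language model (BERT by default) and a CLIP image-text matching model,
then polishes a caption over a number of Gibbs-sampling iterations. Generation can optionally
be controlled by sentiment (positive/negative) or by a predefined part-of-speech template.
"""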
from utils import create_logger, set_seed, format_output
import os
import time
import argparse
import json
from PIL import Image
import torch
import gradio as gr
import nltk
from clip import CLIP
from gen_utils import generate_caption
from control_gen_utils import control_generate_caption
from transformers import AutoModelForMaskedLM, AutoTokenizer


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=1, help="Only supports batch_size=1 currently.")
    parser.add_argument("--device", type=str,
                        default='cpu', choices=['cuda', 'cpu'])

    ## Generation and controllable type
    parser.add_argument('--run_type',
                        default='caption',
                        nargs='?',
                        choices=['caption', 'controllable'])
    parser.add_argument('--prompt',
                        default='Image of a', type=str)
    parser.add_argument('--order',
                        default='shuffle',
                        nargs='?',
                        choices=['sequential', 'shuffle', 'span', 'random', 'parallel'],
                        help="Generation order of text")
    parser.add_argument('--control_type',
                        default='sentiment',
                        nargs='?',
                        choices=["sentiment", "pos"],
                        help="which controllable task to conduct")
    parser.add_argument('--pos_type', type=list,
                        default=[['DET'], ['ADJ', 'NOUN'], ['NOUN'],
                                 ['VERB'], ['VERB'], ['ADV'], ['ADP'],
                                 ['DET', 'NOUN'], ['NOUN'], ['NOUN', '.'],
                                 ['.', 'NOUN'], ['.', 'NOUN']],
                        help="predefined part-of-speech template")
    parser.add_argument('--sentiment_type',
                        default="positive",
                        nargs='?',
                        choices=["positive", "negative"])
    parser.add_argument('--samples_num',
                        default=2, type=int)

    ## Hyperparameters
    parser.add_argument("--sentence_len", type=int, default=10)
    parser.add_argument("--candidate_k", type=int, default=200)
    parser.add_argument("--alpha", type=float, default=0.02, help="weight for fluency")
    parser.add_argument("--beta", type=float, default=2.0, help="weight for image-matching degree")
    parser.add_argument("--gamma", type=float, default=5.0, help="weight for controllable degree")
    parser.add_argument("--lm_temperature", type=float, default=0.1)
    parser.add_argument("--num_iterations", type=int, default=1, help="predefined iterations for Gibbs sampling")

    ## Models and paths
    parser.add_argument("--lm_model", type=str, default='bert-base-uncased',
                        help="Path to language model")  # bert, roberta
    parser.add_argument("--match_model", type=str, default='openai/clip-vit-base-patch32',
                        help="Path to image-text matching model")  # clip, align
    parser.add_argument("--caption_img_path", type=str, default='./examples/girl.jpg',
                        help="File path of the image for captioning")
    parser.add_argument("--stop_words_path", type=str, default='stop_words.txt',
                        help="Path to stop_words.txt")
    parser.add_argument("--add_extra_stopwords", type=list, default=[],
                        help="you can add some extra stop words")

    args = parser.parse_args()
    return args
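

# Caption generation: draw `samples_num` captions for one image, each produced by the
# sampling-based polishing loop in gen_utils.generate_caption. The generator returns a
# list of texts; the last two are reported as the final-iteration caption and the best caption.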
def run_caption(args, image, lm_model, lm_tokenizer, clip, token_mask, logger):
    FinalCaptionList = []
    BestCaptionList = []
    # logger.info(f"Processing: {image_path}")
    image_instance = image.convert("RGB")
    for sample_id in range(args.samples_num):
        logger.info(f"Sample {sample_id}: ")
        gen_texts, clip_scores = generate_caption(lm_model, clip, lm_tokenizer, image_instance, token_mask, logger,
                                                  prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len,
                                                  top_k=args.candidate_k, temperature=args.lm_temperature,
                                                  max_iter=args.num_iterations, alpha=args.alpha, beta=args.beta,
                                                  generate_order=args.order)
        FinalCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-2]
        BestCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-1]
        FinalCaptionList.append(FinalCaptionStr)
        BestCaptionList.append(BestCaptionStr)
    return FinalCaptionList, BestCaptionList
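

# Controllable generation: same loop as run_caption, but control_generate_caption additionally
# receives the control signal (sentiment polarity or part-of-speech template), weighted by gamma.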
def run_control(run_type, args, image, lm_model, lm_tokenizer, clip, token_mask, logger):
    FinalCaptionList = []
    BestCaptionList = []
    # logger.info(f"Processing: {image_path}")
    image_instance = image.convert("RGB")
    for sample_id in range(args.samples_num):
        logger.info(f"Sample {sample_id}: ")
        gen_texts, clip_scores = control_generate_caption(lm_model, clip, lm_tokenizer, image_instance, token_mask, logger,
                                                          prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len,
                                                          top_k=args.candidate_k, temperature=args.lm_temperature,
                                                          max_iter=args.num_iterations, alpha=args.alpha,
                                                          beta=args.beta, gamma=args.gamma,
                                                          ctl_type=args.control_type, style_type=args.sentiment_type,
                                                          pos_type=args.pos_type, generate_order=args.order)
        FinalCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-2]
        BestCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-1]
        FinalCaptionList.append(FinalCaptionStr)
        BestCaptionList.append(BestCaptionStr)
    return FinalCaptionList, BestCaptionList
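

# Gradio callback: rebuild the CLI defaults, overwrite them with the UI selections, load the
# language model and CLIP, build the stop-word token mask, and run the selected generation mode.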
def Demo(RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma, Img):
    args = get_args()
    set_seed(args.seed)

    # Override the CLI defaults with the values selected in the UI.
    args.num_iterations = NumIterations
    args.sentence_len = Length
    args.run_type = RunType
    args.control_type = ControlType
    args.sentiment_type = SentimentType
    args.alpha = Alpha
    args.beta = Beta
    args.gamma = Gamma
    args.samples_num = SamplesNum
    args.order = Order
    img = Img

    run_type = "caption" if args.run_type == "caption" else args.control_type
    if run_type == "sentiment":
        run_type = args.sentiment_type

    if not os.path.exists("logger"):
        os.mkdir("logger")
    logger = create_logger(
        "logger", 'demo_{}_{}_len{}_topk{}_alpha{}_beta{}_gamma{}_lmtemp{}_{}.log'.format(
            run_type, args.order, args.sentence_len,
            args.candidate_k, args.alpha, args.beta, args.gamma, args.lm_temperature,
            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())))
    logger.info(f"Generating order: {args.order}")
    logger.info(f"Run type: {run_type}")
    logger.info(args)

    # Load pre-trained models (weights)
    lm_model = AutoModelForMaskedLM.from_pretrained(args.lm_model)
    lm_tokenizer = AutoTokenizer.from_pretrained(args.lm_model)
    lm_model.eval()
    clip = CLIP(args.match_model)
    clip.eval()
    lm_model = lm_model.to(args.device)
    clip = clip.to(args.device)

    ## Remove stop words: build a token mask that zeros out stop-word ids
    with open(args.stop_words_path, 'r', encoding='utf-8') as stop_words_file:
        stop_words = stop_words_file.readlines()
    stop_words_ = [stop_word.rstrip('\n') for stop_word in stop_words]
    stop_words_ += args.add_extra_stopwords
    stop_ids = lm_tokenizer.convert_tokens_to_ids(stop_words_)
    token_mask = torch.ones((1, lm_tokenizer.vocab_size))
    for stop_id in stop_ids:
        token_mask[0, stop_id] = 0
    token_mask = token_mask.to(args.device)

    if args.run_type == 'caption':
        FinalCaption, BestCaption = run_caption(args, img, lm_model, lm_tokenizer, clip, token_mask, logger)
    elif args.run_type == 'controllable':
        FinalCaption, BestCaption = run_control(run_type, args, img, lm_model, lm_tokenizer, clip, token_mask, logger)
    else:
        raise Exception('run_type must be caption or controllable!')
    # Detach the handlers so repeated Demo calls do not keep writing to stale log files.
    logger.handlers = []

    FinalCaptionFormat, BestCaptionFormat = format_output(SamplesNum, FinalCaption, BestCaption)
    return FinalCaptionFormat, BestCaptionFormat


def RunTypeChange(choice):
    # Show the controllable-only options when "controllable" is selected, hide them for plain captioning.
    if choice == "caption":
        return gr.update(visible=False)
    elif choice == "controllable":
        return gr.update(visible=True)


def ControlTypeChange(choice):
    # The sentiment selector is only relevant for sentiment control, not for POS control.
    if choice == "pos":
        return gr.update(visible=False)
    elif choice == "sentiment":
        return gr.update(visible=True)
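

# UI layout: run-type, control, and sampling parameters in the left column; image upload,
# caption outputs, and the Submit / Reset buttons in the right column.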
with gr.Blocks() as demo:
    gr.Markdown("""
    # ConZIC
    ### Controllable Zero-shot Image Captioning by Sampling-Based Polishing
    """)
    with gr.Row():
        with gr.Column():
            RunType = gr.Radio(
                ["caption", "controllable"], value="caption", label="Run Type", info="Select the Run Type"
            )
            ControlType = gr.Radio(
                ["sentiment", "pos"], value="sentiment", label="Control Type", info="Select the Control Type",
                visible=False, interactive=True
            )
            SentimentType = gr.Radio(
                ["positive", "negative"], value="positive", label="Sentiment Type", info="Select the Sentiment Type",
                visible=False, interactive=True
            )
            Order = gr.Radio(
                ["sequential", "shuffle", "random"], value="shuffle", label="Order", info="Generation order of text"
            )
            RunType.change(fn=RunTypeChange, inputs=RunType, outputs=SentimentType)
            RunType.change(fn=RunTypeChange, inputs=RunType, outputs=ControlType)
            ControlType.change(fn=ControlTypeChange, inputs=ControlType, outputs=SentimentType)
            with gr.Row():
                Length = gr.Slider(
                    5, 15, value=10, label="Sentence Length", info="Choose between 5 and 15", step=1
                )
                NumIterations = gr.Slider(
                    1, 15, value=10, label="Num Iterations", info="Predefined iterations for Gibbs sampling", step=1
                )
            with gr.Row():
                SamplesNum = gr.Slider(
                    1, 5, value=2, label="Samples Num", step=1
                )
                Alpha = gr.Slider(
                    0, 1, value=0.02, label="Alpha", info="Weight for fluency", step=0.01
                )
            with gr.Row():
                Beta = gr.Slider(
                    1, 5, value=2, label="Beta", info="Weight for image-matching degree", step=0.5
                )
                Gamma = gr.Slider(
                    1, 10, value=5, label="Gamma", info="Weight for controllable degree", step=0.5
                )
        with gr.Column():
            Img = gr.Image(label="Upload Picture", type="pil")
            FinalCaption = gr.Textbox(label="Final Caption", lines=5, placeholder="Final Caption")
            BestCaption = gr.Textbox(label="Best Caption", lines=5, placeholder="Best Caption")
            with gr.Row():
                gen_button = gr.Button("Submit")
                clear_button = gr.Button("Reset")

    gen_button.click(
        fn=Demo,
        inputs=[
            RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma, Img
        ],
        outputs=[
            FinalCaption, BestCaption
        ]
    )
    clear_button.click(
        fn=lambda: [gr.Radio.update(value='caption'), gr.Radio.update(value='pos'), gr.Radio.update(value='positive'),
                    gr.Radio.update(value='shuffle'), gr.Slider.update(value=10), gr.Slider.update(value=10),
                    gr.Slider.update(value=2), gr.Slider.update(value=0.02), gr.Slider.update(value=2),
                    gr.Slider.update(value=5)],
        inputs=[],
        outputs=[
            RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma
        ]
    )


if __name__ == "__main__":
    # Download the NLTK resources used for tokenization, POS tagging, and sentiment scoring.
    nltk.download('wordnet')
    nltk.download('punkt')
    nltk.download('averaged_perceptron_tagger')
    nltk.download('sentiwordnet')

    demo.launch()